提交 faee95e7 authored 作者: James Bergstra's avatar James Bergstra

fixed ticket #326

上级 0696cbd2
......@@ -207,7 +207,7 @@ def _optcheck_env(input_specs, output_specs, accept_inplace = False):
env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(env, 'destroyers') and env.destroyers(input)))))
return env, map(SymbolicOutput, updates), equivalence_tracker
def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes):
def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, clobber_dr_vals=True):
"""Raise BadDestroyMap if necessary, update dr_vals"""
destroyed_idx_list = []
destroy_map = getattr(node.op, 'destroy_map', {})
......@@ -225,7 +225,9 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes):
if dr_vals.get(r, (0, node))[1] is not node:
# bad: there should only be one active node that destroys any result
raise Exception('failure in topological ordering')
dr_vals[r] = (storage_map[r][0], node) #no copy, this is the last use of this variable
if clobber_dr_vals:
dr_vals[r] = (storage_map[r][0], node) #no copy, this is the last use of this variable
storage_map[r][0] = None #make sure that dr_vals[r] doens't get used again
else:
raise BadDestroyMap(node, r_idx, r_vals[r], storage_map[r][0])
......@@ -572,10 +574,10 @@ class _Linker(gof.link.LocalLinker):
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunks_c.append(thunk)
except (NotImplementedError, utils.AbstractFunctionError):
thunks_c.append(None)
p = node.op.perform
thunk = (lambda p = p, i = node_input_storage, o = node_output_storage, n =
node: p(n, [x[0] for x in i], o))
......@@ -640,7 +642,8 @@ class _Linker(gof.link.LocalLinker):
thunk_py()
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set)
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=True)
#retrieve each output from the storage_map
for r in node.outputs:
......@@ -660,7 +663,8 @@ class _Linker(gof.link.LocalLinker):
thunk_c()
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set)
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=False)
for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论