提交 faee95e7 authored 作者: James Bergstra's avatar James Bergstra

fixed ticket #326

上级 0696cbd2
...@@ -207,7 +207,7 @@ def _optcheck_env(input_specs, output_specs, accept_inplace = False): ...@@ -207,7 +207,7 @@ def _optcheck_env(input_specs, output_specs, accept_inplace = False):
env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(env, 'destroyers') and env.destroyers(input))))) env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(env, 'destroyers') and env.destroyers(input)))))
return env, map(SymbolicOutput, updates), equivalence_tracker return env, map(SymbolicOutput, updates), equivalence_tracker
def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes): def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, clobber_dr_vals=True):
"""Raise BadDestroyMap if necessary, update dr_vals""" """Raise BadDestroyMap if necessary, update dr_vals"""
destroyed_idx_list = [] destroyed_idx_list = []
destroy_map = getattr(node.op, 'destroy_map', {}) destroy_map = getattr(node.op, 'destroy_map', {})
...@@ -225,7 +225,9 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes): ...@@ -225,7 +225,9 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes):
if dr_vals.get(r, (0, node))[1] is not node: if dr_vals.get(r, (0, node))[1] is not node:
# bad: there should only be one active node that destroys any result # bad: there should only be one active node that destroys any result
raise Exception('failure in topological ordering') raise Exception('failure in topological ordering')
if clobber_dr_vals:
dr_vals[r] = (storage_map[r][0], node) #no copy, this is the last use of this variable dr_vals[r] = (storage_map[r][0], node) #no copy, this is the last use of this variable
storage_map[r][0] = None #make sure that dr_vals[r] doens't get used again
else: else:
raise BadDestroyMap(node, r_idx, r_vals[r], storage_map[r][0]) raise BadDestroyMap(node, r_idx, r_vals[r], storage_map[r][0])
...@@ -572,10 +574,10 @@ class _Linker(gof.link.LocalLinker): ...@@ -572,10 +574,10 @@ class _Linker(gof.link.LocalLinker):
thunk.inputs = node_input_storage thunk.inputs = node_input_storage
thunk.outputs = node_output_storage thunk.outputs = node_output_storage
thunks_c.append(thunk) thunks_c.append(thunk)
except (NotImplementedError, utils.AbstractFunctionError): except (NotImplementedError, utils.AbstractFunctionError):
thunks_c.append(None) thunks_c.append(None)
p = node.op.perform p = node.op.perform
thunk = (lambda p = p, i = node_input_storage, o = node_output_storage, n = thunk = (lambda p = p, i = node_input_storage, o = node_output_storage, n =
node: p(n, [x[0] for x in i], o)) node: p(n, [x[0] for x in i], o))
...@@ -640,7 +642,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -640,7 +642,8 @@ class _Linker(gof.link.LocalLinker):
thunk_py() thunk_py()
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set) _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=True)
#retrieve each output from the storage_map #retrieve each output from the storage_map
for r in node.outputs: for r in node.outputs:
...@@ -660,7 +663,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -660,7 +663,8 @@ class _Linker(gof.link.LocalLinker):
thunk_c() thunk_c()
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set) _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=False)
for r in node.outputs: for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论