提交 5a42b909 authored 作者: carriepl's avatar carriepl

Convert double recursion to simple recursion

上级 f175d4b8
......@@ -83,7 +83,17 @@ def rebuild_collect_shared(outputs,
if v in clone_d:
return clone_d[v]
if v.owner:
clone_a(v.owner, copy_inputs_over)
owner = v.owner
if owner not in clone_d:
for i in owner.inputs:
clone_v_get_shared_updates(i, copy_inputs_over)
clone_d[owner] = owner.clone_with_new_inputs(
[clone_d[i] for i in owner.inputs],
strict=rebuild_strict)
for old_o, new_o in zip(owner.outputs, clone_d[owner].outputs):
clone_d.setdefault(old_o, new_o)
return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable):
if v not in shared_inputs:
......@@ -114,25 +124,6 @@ def rebuild_collect_shared(outputs,
else:
return clone_d.setdefault(v, v)
def clone_a(a, copy_inputs_over):
"""
Clones a variable and its inputs recursively until all are in
clone_d. It occures with clone_v_get_shared_updates.
"""
if a is None:
return None
if a not in clone_d:
for i in a.inputs:
clone_v_get_shared_updates(i, copy_inputs_over)
clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in
a.inputs],
strict=rebuild_strict)
for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
clone_d.setdefault(old_o, new_o)
return clone_d[a]
# intialize the clone_d mapping with the replace dictionary
if replace is None:
replace = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论