提交 730f1d87 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix a problem in clone preventing graphs from having a depth > 1000.

clone() does not use recursion anymore.
上级 29adf634
...@@ -349,7 +349,6 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True): ...@@ -349,7 +349,6 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
""" """
d = {} d = {}
for input in i: for input in i:
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
cpy = input.clone() cpy = input.clone()
...@@ -359,29 +358,57 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True): ...@@ -359,29 +358,57 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
else: else:
d[input] = input d[input] = input
def clone_helper(result):
if result in d:
return d[result]
node = result.owner
if node is None: # result is an orphan
if copy_inputs_and_orphans:
cpy = result.clone()
d[result] = cpy
else:
d[result] = result
return d[result]
else:
new_node = node.clone_with_new_inputs([clone_helper(input) for input in node.inputs])
d[node] = new_node
for output, new_output in zip(node.outputs, new_node.outputs):
d[output] = new_output
return d[result]
for output in o:
clone_helper(output) for apply in io_toposort(i, o):
for input in apply.inputs:
if input not in d:
if copy_inputs_and_orphans:
cpy = input.clone()
d[input] = cpy
else:
d[input] = input
new_apply = apply.clone_with_new_inputs([d[i] for i in apply.inputs])
d[apply] = new_apply
for output, new_output in zip(apply.outputs, new_apply.outputs):
d[output] = new_output
return d return d
## Previous version
# for input in i:
# if copy_inputs_and_orphans:
# cpy = input.clone()
# cpy.owner = None
# cpy.index = None
# d[input] = cpy
# else:
# d[input] = input
#
# def clone_helper(result):
# if result in d:
# return d[result]
# node = result.owner
# if node is None: # result is an orphan
# if copy_inputs_and_orphans:
# cpy = result.clone()
# d[result] = cpy
# else:
# d[result] = result
# return d[result]
# else:
# new_node = node.clone_with_new_inputs([clone_helper(input) for input in node.inputs])
# d[node] = new_node
# for output, new_output in zip(node.outputs, new_node.outputs):
# d[output] = new_output
# return d[result]
#
# for output in o:
# clone_helper(output)
#
# return d
# def clone_with_new_inputs(i, o, new_i): # def clone_with_new_inputs(i, o, new_i):
# equiv = clone_with_new_inputs_get_equiv(i, o, new_i) # equiv = clone_with_new_inputs_get_equiv(i, o, new_i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论