提交 ab3f5eaf authored 作者: James Bergstra's avatar James Bergstra

removed clone_with_equiv because it seems not to be used anywhere

上级 b8e4b392
...@@ -376,127 +376,6 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True): ...@@ -376,127 +376,6 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
return d return d
## Previous version
# for input in i:
# if copy_inputs_and_orphans:
# cpy = input.clone()
# cpy.owner = None
# cpy.index = None
# d[input] = cpy
# else:
# d[input] = input
#
# def clone_helper(result):
# if result in d:
# return d[result]
# node = result.owner
# if node is None: # result is an orphan
# if copy_inputs_and_orphans:
# cpy = result.clone()
# d[result] = cpy
# else:
# d[result] = result
# return d[result]
# else:
# new_node = node.clone_with_new_inputs([clone_helper(input) for input in node.inputs])
# d[node] = new_node
# for output, new_output in zip(node.outputs, new_node.outputs):
# d[output] = new_output
# return d[result]
#
# for output in o:
# clone_helper(output)
#
# return d
# def clone_with_new_inputs(i, o, new_i):
# equiv = clone_with_new_inputs_get_equiv(i, o, new_i)
# return [equiv[input] for input in i], [equiv[output] for output in o]
# def clone_with_new_inputs_get_equiv(i, o, new_i, copy_orphans = True):
# # note: this does not exactly mirror Apply.clone_with_new_inputs
# # here it is possible to give different types to new_i and then
# # make_node is called on the ops instead of clone_with_new_inputs
# # whenever the type is different.
# d = {}
# for input, new_input in zip(i, new_i):
# d[input] = new_input
# def clone_helper(result):
# if result in d:
# return d[result]
# node = result.owner
# if node is None: # result is an orphan
# if copy_orphans:
# cpy = result.clone()
# d[result] = cpy
# else:
# d[result] = result
# return d[result]
# else:
# cloned_inputs = [clone_helper(input) for input in node.inputs]
# if any(input != cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
# new_node = node.op.make_node(*cloned_inputs)
# else:
# new_node = node.clone_with_new_inputs(cloned_inputs)
# d[node] = new_node
# for output, new_output in zip(node.outputs, new_node.outputs):
# d[output] = new_output
# return d[result]
# for output in o:
# clone_helper(output)
# return d
def clone_with_equiv(i, o, d, missing_input_policy = 'fail', orphan_policy = 'copy'):
def clone_helper(result):
if result in d:
return d[result]
node = result.owner
if node is None: # result is an input or an orphan not in d
if isinstance(result, Value):
if orphan_policy == 'copy':
d[result] = copy(result)
elif orphan_policy == 'keep':
d[result] = result
else:
raise ValueError("unknown orphan_policy: '%s'" % orphan_policy)
else:
if missing_input_policy == 'fail':
raise ValueError("missing input: %s" % result)
elif missing_input_policy == 'keep':
d[result] = result
else:
raise ValueError("unknown missing_input_policy: '%s'" % missing_input_policy)
return d[result]
else:
cloned_inputs = [clone_helper(input) for input in node.inputs]
if all(input is cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
new_node = node
else:
new_node = node.clone_with_new_inputs(cloned_inputs, strict = False)
# if any(input != cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
# new_node = node.op.make_node(*cloned_inputs)
# else:
# new_node = node.clone_with_new_inputs(cloned_inputs)
d[node] = new_node
for output, new_output in zip(node.outputs, new_node.outputs):
d[output] = new_output
return d[result]
for output in o:
clone_helper(output)
return [d[input] for input in i], [d[output] for output in o]
def general_toposort(r_out, deps, debug_print = False): def general_toposort(r_out, deps, debug_print = False):
""" """
@note: deps(i) should behave like a pure function (no funny business with @note: deps(i) should behave like a pure function (no funny business with
...@@ -561,8 +440,6 @@ def io_toposort(i, o, orderings = {}): ...@@ -561,8 +440,6 @@ def io_toposort(i, o, orderings = {}):
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
default_leaf_formatter = str default_leaf_formatter = str
default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op, default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
", ".join(argstrings)) ", ".join(argstrings))
...@@ -667,3 +544,4 @@ def view_roots(r): ...@@ -667,3 +544,4 @@ def view_roots(r):
else: else:
return [r] return [r]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论