提交 f25541bc authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed graph.clone_equiv to copy orphans as well as inputs and added Op.clone_with_new_inputs

上级 fe80ad1a
...@@ -135,13 +135,14 @@ def clone(i, o, copy_inputs = False): ...@@ -135,13 +135,14 @@ def clone(i, o, copy_inputs = False):
return [equiv[output] for output in o] return [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs = False): def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
""" """
i -> list of input Results i -> list of input Results
o -> list of output Results o -> list of output Results
copy_inputs -> if True, the inputs will be replaced in the cloned copy_inputs_and_orphans -> if True, the inputs and the orphans
graph by copies available in the equiv dictionary will be replaced in the cloned graph by copies available in
returned by the function (copy_inputs defaults to False) the equiv dictionary returned by the function (copy_inputs
defaults to False)
Returns equiv a dictionary mapping each result and op in the Returns equiv a dictionary mapping each result and op in the
graph delimited by i and o to a copy (akin to deepcopy's memo). graph delimited by i and o to a copy (akin to deepcopy's memo).
...@@ -150,8 +151,8 @@ def clone_get_equiv(i, o, copy_inputs = False): ...@@ -150,8 +151,8 @@ def clone_get_equiv(i, o, copy_inputs = False):
d = {} d = {}
for input in i: for input in i:
if copy_inputs: if copy_inputs_and_orphans:
d[input] = copy(input) d[input] = input.clone(True)
else: else:
d[input] = input d[input] = input
...@@ -160,9 +161,13 @@ def clone_get_equiv(i, o, copy_inputs = False): ...@@ -160,9 +161,13 @@ def clone_get_equiv(i, o, copy_inputs = False):
return d[result] return d[result]
op = result.owner op = result.owner
if not op: if not op:
return result if copy_inputs_and_orphans:
d[result] = result.clone(True)
else:
d[result] = result
return d[result]
else: else:
new_op = op.__class__(*[clone_helper(input) for input in op.inputs]) new_op = op.clone_with_new_inputs(*[clone_helper(input) for input in op.inputs])
d[op] = new_op d[op] = new_op
for output, new_output in zip(op.outputs, new_op.outputs): for output, new_output in zip(op.outputs, new_op.outputs):
d[output] = new_output d[output] = new_output
......
...@@ -88,12 +88,21 @@ class Op(object): ...@@ -88,12 +88,21 @@ class Op(object):
""" """
Shallow copy of this Op. The inputs are the exact same, but Shallow copy of this Op. The inputs are the exact same, but
the outputs are recreated because of the one-owner-per-result the outputs are recreated because of the one-owner-per-result
policy. policy. The default behavior is to call the constructor on
this Op's inputs.
This implementation permits a bottom-up copy of an entire graph. To do a bottom-up copy of a graph, use clone_with_new_inputs.
""" """
return self.__class__(*self.inputs) return self.__class__(*self.inputs)
def clone_with_new_inputs(self, *new_inputs):
"""
Returns a clone of this Op that takes different inputs. The
default behavior is to call the constructor on the new inputs,
but if your Op has additional options or a different constructor
you might want to override this.
"""
return self.__class__(*new_inputs)
# #
# String representation # String representation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论