提交 56501f9d authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6059 from rizar/always_copy_orphas

separate copying inputs and copying orphans
......@@ -799,9 +799,8 @@ def orphans(i, o):
return variables_and_orphans(i, o)[1]
def clone(i, o, copy_inputs=True):
"""
Copies the subgraph contained between i and o.
def clone(i, o, copy_inputs=True, copy_orphans=None):
"""Copies the subgraph contained between i and o.
Parameters
----------
......@@ -811,18 +810,32 @@ def clone(i, o, copy_inputs=True):
Output Variables.
copy_inputs : bool
If True, the inputs will be copied (defaults to True).
copy_orphans:
When None, use the copy_inputs value,
When True, new orphans nodes are created.
When False, original orphans nodes are reused in the new graph.
Returns
-------
object
The inputs and outputs of that copy.
Note
----
A constant, if in the ``i`` list is not an orpha. So it will be
copied depending of the ``copy_inputs`` parameter. Otherwise it
will be copied depending of the ``copy_orphans`` parameter.
"""
equiv = clone_get_equiv(i, o, copy_inputs)
if copy_orphans is None:
copy_orphans = copy_inputs
equiv = clone_get_equiv(i, o, copy_inputs, copy_orphans)
return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True,
memo=None):
"""
Return a dictionary that maps from Variable and Apply nodes in the
original graph to a new node (a clone) in a new graph.
......@@ -834,11 +847,14 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
----------
inputs : a list of Variables
outputs : a list of Variables
copy_inputs_and_orphans : bool
True means to create the cloned graph from new input and constant
copy_inputs : bool
True means to create the cloned graph from new input
nodes (the bottom of a feed-upward graph).
False means to clone a graph that is rooted at the original input
nodes.
copy_orphans:
When True, new constant nodes are created. When False, original
constant nodes are reused in the new graph.
memo : None or dict
Optionally start with a partly-filled dictionary for the return value.
If a dictionary is passed, this function will work in-place on that
......@@ -850,7 +866,7 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
# clone the inputs if necessary
for input in inputs:
if copy_inputs_and_orphans:
if copy_inputs:
cpy = input.clone()
cpy.owner = None
cpy.index = None
......@@ -862,7 +878,7 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
for apply in io_toposort(inputs, outputs):
for input in apply.inputs:
if input not in memo:
if copy_inputs_and_orphans:
if copy_orphans:
cpy = input.clone()
memo[input] = cpy
else:
......
......@@ -156,6 +156,26 @@ class TestClone(X):
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
def test_constant(self):
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5)
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8)
c1 = tensor.constant(1.5)
i, o = clone([c1], [c1])
assert i[0] is not c1 and o[0] is not c1
i, o = clone([c1], [c1], False)
assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], True, False)
assert i[0] is not c1 and o[0] is not c1
i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1
############
# toposort #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论