提交 e84346eb authored 作者: Frederic Bastien's avatar Frederic Bastien

Allow theano.clone() to specify or not to copy orphans. Default not changed.

上级 ca1ebe71
...@@ -799,9 +799,8 @@ def orphans(i, o): ...@@ -799,9 +799,8 @@ def orphans(i, o):
return variables_and_orphans(i, o)[1] return variables_and_orphans(i, o)[1]
def clone(i, o, copy_inputs=True): def clone(i, o, copy_inputs=True, copy_orphans=None):
""" """Copies the subgraph contained between i and o.
Copies the subgraph contained between i and o.
Parameters Parameters
---------- ----------
...@@ -811,18 +810,32 @@ def clone(i, o, copy_inputs=True): ...@@ -811,18 +810,32 @@ def clone(i, o, copy_inputs=True):
Output Variables. Output Variables.
copy_inputs : bool copy_inputs : bool
If True, the inputs will be copied (defaults to True). If True, the inputs will be copied (defaults to True).
copy_orphans:
When None, use the copy_inputs value,
When True, new orphans nodes are created.
When False, original orphans nodes are reused in the new graph.
Returns Returns
------- -------
object object
The inputs and outputs of that copy. The inputs and outputs of that copy.
Note
----
A constant, if in the ``i`` list is not an orpha. So it will be
copied depending of the ``copy_inputs`` parameter. Otherwise it
will be copied depending of the ``copy_orphans`` parameter.
""" """
equiv = clone_get_equiv(i, o, copy_inputs) if copy_orphans is None:
copy_orphans = copy_inputs
equiv = clone_get_equiv(i, o, copy_inputs, copy_orphans)
return [equiv[input] for input in i], [equiv[output] for output in o] return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True, memo=None): def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True,
memo=None):
""" """
Return a dictionary that maps from Variable and Apply nodes in the Return a dictionary that maps from Variable and Apply nodes in the
original graph to a new node (a clone) in a new graph. original graph to a new node (a clone) in a new graph.
......
...@@ -156,6 +156,26 @@ class TestClone(X): ...@@ -156,6 +156,26 @@ class TestClone(X):
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"] assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"] assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
def test_constant(self):
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5)
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8)
c1 = tensor.constant(1.5)
i, o = clone([c1], [c1])
assert i[0] is not c1 and o[0] is not c1
i, o = clone([c1], [c1], False)
assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], True, False)
assert i[0] is not c1 and o[0] is not c1
i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1
############ ############
# toposort # # toposort #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论