提交 5d1ecb0b authored 作者: nouiz's avatar nouiz

Merge pull request #685 from jaberg/clone_get_equiv

Clone get equiv
......@@ -575,54 +575,67 @@ def clone(i, o, copy_inputs = True):
return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
""" WRITEME
def clone_get_equiv(inputs, outputs,
copy_inputs_and_orphans=True,
memo=None):
"""
Return a dictionary that maps from Variable and Apply nodes in the
original graph to a new node (a clone) in a new graph.
:type i: list
:param i: input L{Variable}s
:type o: list
:param o: output L{Variable}s
:type copy_inputs_and_orphans: bool
:param copy_inputs_and_orphans:
if True, the inputs and the orphans will be replaced in the cloned graph by copies
available in the equiv dictionary returned by the function (copy_inputs_and_orphans
defaults to True)
This function works by recursively cloning inputs... rebuilding a directed
graph from the bottom (inputs) up to eventually building new outputs.
Parameters
----------
inputs: a list of Variables
outputs: a list of Variables
copy_inputs_and_orphans: bool
True means to create the cloned graph from new input and constant
nodes (the bottom of a feed-upward graph),
False means to clone a graph that is rooted at the original input
nodes.
memo: None or dict
Optionally start with a partly-filled dictionary for the return value.
If a dictionary is passed, this function will work in-place on that
dictionary and return it.
:rtype: a dictionary
:return:
equiv mapping each L{Variable} and L{Op} in the graph delimited by i and o to a copy
(akin to deepcopy's memo).
"""
d = {}
for input in i:
if memo is None:
memo = {}
# clone the inputs if necessary
for input in inputs:
if copy_inputs_and_orphans:
cpy = input.clone()
cpy.owner = None
cpy.index = None
d[input] = cpy
memo.setdefault(input, cpy)
else:
d[input] = input
memo.setdefault(input, input)
for apply in io_toposort(i, o):
# go through the inputs -> outputs graph cloning as we go
for apply in io_toposort(inputs, outputs):
for input in apply.inputs:
if input not in d:
if input not in memo:
if copy_inputs_and_orphans:
cpy = input.clone()
d[input] = cpy
memo[input] = cpy
else:
d[input] = input
memo[input] = input
new_apply = apply.clone_with_new_inputs([d[i] for i in apply.inputs])
d[apply] = new_apply
new_apply = apply.clone_with_new_inputs([memo[i] for i in apply.inputs])
memo.setdefault(apply, new_apply)
for output, new_output in zip(apply.outputs, new_apply.outputs):
d[output] = new_output
memo.setdefault(output, new_output)
for output in o:
if output not in d:
d[output] = output.clone()
# finish up by cloning any remaining outputs (it can happen)
for output in outputs:
if output not in memo:
memo[output] = output.clone()
return memo
return d
def general_toposort(r_out, deps, debug_print = False):
"""WRITEME
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论