提交 9c765de2 authored 作者: James Bergstra's avatar James Bergstra

docs and comments for clone_get_equiv

上级 7a71371c
...@@ -575,29 +575,37 @@ def clone(i, o, copy_inputs = True): ...@@ -575,29 +575,37 @@ def clone(i, o, copy_inputs = True):
return [equiv[input] for input in i], [equiv[output] for output in o] return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs_and_orphans=True, memo=None): def clone_get_equiv(inputs, outputs,
""" WRITEME copy_inputs_and_orphans=True,
memo=None):
"""
Return a dictionary that maps from Variable and Apply nodes in the
original graph to a new node (a clone) in a new graph.
This function works by recursively cloning inputs... rebuilding a directed
graph from the bottom (inputs) up to eventually building new outputs.
Parameters
----------
inputs: a list of Variables
outputs: a list of Variables
copy_inputs_and_orphans: bool
True means to create the cloned graph from new input and constant
nodes (the bottom of a feed-upward graph),
False means to clone a graph that is rooted at the original input
nodes.
memo: None or dict
Optionally start with a partly-filled dictionary for the return value.
If a dictionary is passed, this function will work in-place on that
dictionary and return it.
:type i: list
:param i: input L{Variable}s
:type o: list
:param o: output L{Variable}s
:type copy_inputs_and_orphans: bool
:param copy_inputs_and_orphans:
if True, the inputs and the orphans will be replaced in the cloned graph by copies
available in the equiv dictionary returned by the function (copy_inputs_and_orphans
defaults to True)
:rtype: a dictionary
:return:
equiv mapping each L{Variable} and L{Op} in the graph delimited by i and o to a copy
(akin to deepcopy's memo).
""" """
if memo is None: if memo is None:
memo = {} memo = {}
for input in i: # clone the inputs if necessary
for input in inputs:
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
cpy = input.clone() cpy = input.clone()
cpy.owner = None cpy.owner = None
...@@ -606,7 +614,8 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans=True, memo=None): ...@@ -606,7 +614,8 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans=True, memo=None):
else: else:
memo.setdefault(input, input) memo.setdefault(input, input)
for apply in io_toposort(i, o): # go through the inputs -> outputs graph cloning as we go
for apply in io_toposort(inputs, outputs):
for input in apply.inputs: for input in apply.inputs:
if input not in memo: if input not in memo:
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
...@@ -620,12 +629,14 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans=True, memo=None): ...@@ -620,12 +629,14 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans=True, memo=None):
for output, new_output in zip(apply.outputs, new_apply.outputs): for output, new_output in zip(apply.outputs, new_apply.outputs):
memo.setdefault(output, new_output) memo.setdefault(output, new_output)
for output in o: # finish up by cloning any remaining outputs (it can happen)
for output in outputs:
if output not in memo: if output not in memo:
memo[output] = output.clone() memo[output] = output.clone()
return memo return memo
def general_toposort(r_out, deps, debug_print = False): def general_toposort(r_out, deps, debug_print = False):
"""WRITEME """WRITEME
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论