提交 712c53a6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary Constant cloning from clone_get_equiv

Since it is no longer necessary to clone `Constant`s, and they add extra work for the merge rewrites, `clone_get_equiv`'s `copy_orphans` option has been prevented from cloning `Constant`s. The option is mostly applicable to constants, but it has been retained for other non-`Constant` cases (just in case) and backward compatibility.
上级 77396f7f
......@@ -1061,25 +1061,29 @@ def clone_get_equiv(
] = None,
clone_inner_graphs: bool = False,
) -> Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]:
r"""
Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph.
r"""Clone the graph between `inputs` and `outputs` and return a map of the cloned objects.
This function works by recursively cloning inputs and rebuilding a directed
graph from the inputs up.
This function works by recursively cloning inputs... rebuilding a directed
graph from the inputs up to eventually building new outputs.
If `memo` already contains entries for some of the objects in the graph,
those objects are replaced with their values in `memo` and *not*
unnecessarily cloned.
Parameters
----------
inputs
Inputs of the graph to be cloned.
outputs
Outputs of the graph to be cloned.
copy_inputs
True means to create the cloned graph from new input
nodes (the bottom of a feed-upward graph).
False means to clone a graph that is rooted at the original input
nodes.
``True`` means to create the cloned graph from cloned input nodes.
``False`` means to clone a graph that is rooted at the original input
nodes. `Constant`\s are *not* cloned.
copy_orphans
When ``True``, new constant nodes are created. When ``False``, original
constant nodes are reused in the new graph.
When ``True``, inputs with no owners are cloned. When ``False``,
original inputs are reused in the new graph. Cloning is *not*
performed for `Constant`\s.
memo
Optionally start with a partly-filled dictionary for the return value.
If a dictionary is passed, this function will work in-place on that
......@@ -1093,7 +1097,7 @@ def clone_get_equiv(
# clone the inputs if necessary
for input in inputs:
if copy_inputs:
if not isinstance(input, Constant) and copy_inputs:
cpy = input.clone()
cpy.owner = None
cpy.index = None
......@@ -1105,7 +1109,7 @@ def clone_get_equiv(
for apply in io_toposort(inputs, outputs):
for input in apply.inputs:
if input not in memo:
if copy_orphans:
if not isinstance(input, Constant) and copy_orphans:
cpy = input.clone()
memo[input] = cpy
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论