提交 3c5e02d2 authored 作者: Frederic's avatar Frederic

Lower triple recursion to dual recursion: remove __prune__

The recursion was: __remove_clients__ -> __prune_r__ -> __prune__ This help gh-3341 This should also speed up change to the graph, but I didn't do any timming.
上级 4be377bb
......@@ -454,9 +454,20 @@ class FunctionGraph(utils.object2):
the variable will stay in fgraph.variables.
"""
# Prunes the owners of the variables.
# Prunes the owners(Apply Node) of the variables.
if variable.owner:
self.__prune__(variable.owner, reason)
apply_node = variable.owner
used_or_output = [output for output in apply_node.outputs
if output.clients or output in self.outputs]
# If the apply node is not used and is not an output
if not used_or_output:
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)],
reason=reason)
# variable should not have any clients.
# assert not variable.clients
......@@ -478,29 +489,6 @@ class FunctionGraph(utils.object2):
# or not.
del variable.fgraph
def __prune__(self, apply_node, reason=None):
"""
Always called on owner of pruned variable from the graph.
This do not mean we will remove it from the graph. If other
outputs are still used, we will keep the node in the graph.
"""
# If apply_node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in apply_node.outputs:
# Cannot prune an op which is an output or used somewhere
if output.clients or output in self.outputs:
return
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs)
# change input #
def change_input(self, node, i, new_r, reason=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论