提交 42cf211b authored 作者: Frederic's avatar Frederic

Add comment and rename var

上级 32a4afb7
......@@ -375,6 +375,14 @@ class FunctionGraph(utils.object2):
### prune ###
def __prune_r__(self, variables, reason=None):
"""Should be called for variables that aren't used anymore:
len(var.clients) == 0
This do not mean we will remove it from fgraph.variables. If
the owner stay in the fgraph as other outputs are still used,
the variable will be stay in fgraph.variables.
"""
# Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None):
self.__prune__(node, reason)
......@@ -383,25 +391,31 @@ class FunctionGraph(utils.object2):
self.variables.remove(r)
def __prune__(self, apply_node, reason=None):
node = apply_node
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
assert node.fgraph is self
# If node's outputs have no clients, removes it from the graph
"""Always called on owner of pruned variable from the graph.
This do not mean we will remove it from the graph. If other
outputs are still used, we will keep the node in the graph.
"""
if apply_node not in self.apply_nodes:
raise Exception(
"%s does not belong to this FunctionGraph." % apply_node)
assert apply_node.fgraph is self
# If apply_node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in node.outputs:
for output in apply_node.outputs:
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: # output in self.outputs or self.clients(output):
if output.clients or output in self.outputs:
return
self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node, reason)
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)], reason=reason)
# self.__prune_r__(node.inputs)
for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs)
### change input ###
def change_input(self, node, i, new_r, reason=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论