提交 42cf211b authored 作者: Frederic's avatar Frederic

Add comment and rename var

上级 32a4afb7
...@@ -375,6 +375,14 @@ class FunctionGraph(utils.object2): ...@@ -375,6 +375,14 @@ class FunctionGraph(utils.object2):
### prune ### ### prune ###
def __prune_r__(self, variables, reason=None): def __prune_r__(self, variables, reason=None):
"""Should be called for variables that aren't used anymore:
len(var.clients) == 0
This do not mean we will remove it from fgraph.variables. If
the owner stay in the fgraph as other outputs are still used,
the variable will be stay in fgraph.variables.
"""
# Prunes the owners of the variables. # Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None): for node in set(r.owner for r in variables if r.owner is not None):
self.__prune__(node, reason) self.__prune__(node, reason)
...@@ -383,25 +391,31 @@ class FunctionGraph(utils.object2): ...@@ -383,25 +391,31 @@ class FunctionGraph(utils.object2):
self.variables.remove(r) self.variables.remove(r)
def __prune__(self, apply_node, reason=None): def __prune__(self, apply_node, reason=None):
node = apply_node """Always called on owner of pruned variable from the graph.
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node) This do not mean we will remove it from the graph. If other
assert node.fgraph is self outputs are still used, we will keep the node in the graph.
# If node's outputs have no clients, removes it from the graph
"""
if apply_node not in self.apply_nodes:
raise Exception(
"%s does not belong to this FunctionGraph." % apply_node)
assert apply_node.fgraph is self
# If apply_node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one # and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client # of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op. # then __prune__ is a no-op.
for output in node.outputs: for output in apply_node.outputs:
# Cannot prune an op which is an output or used somewhere # Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: # output in self.outputs or self.clients(output): if output.clients or output in self.outputs:
return return
self.apply_nodes.remove(node) self.apply_nodes.remove(apply_node)
self.variables.difference_update(node.outputs) self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', node, reason) self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(node.inputs): for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(node, i)], reason=reason) self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(node.inputs) # self.__prune_r__(apply_node.inputs)
### change input ### ### change input ###
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论