提交 29732d84 authored 作者: abergeron's avatar abergeron

Merge pull request #3604 from nouiz/recur

[ENH] Lower triple recursion to simple recursion
...@@ -281,24 +281,72 @@ class FunctionGraph(utils.object2): ...@@ -281,24 +281,72 @@ class FunctionGraph(utils.object2):
""" """
Removes all from the clients list of r. Removes all from the clients list of r.
WRITEME This is the main method to remove variable or apply node from
an FunctionGraph.
If called with an empty list of clients and prune=True, this
will remove the owner of the variable (so an apply_node).
Parameters Parameters
---------- ----------
r r : Variable
Variable. The clients of r will be removed.
clients_to_remove clients_to_remove : List of (op, i) pairs
List of (op, i) pairs such that node.inputs[i] is not r anymore. List of (op, i) pairs such that node.inputs[i] is not r anymore.
prune : bool
If prune is True, it remove r from this fgraph if it don't
have clients left.
Returns
-------
bool
True if r is still in the fgraph and need to be pruned
later. This can happen only when prune is False. A second
call to this method with an empty list for
clients_to_remove and prune=True will remove r.
""" """
for entry in clients_to_remove: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
assert entry not in r.clients # an op,i pair should be unique assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if r.clients:
if prune: return False
self.__prune_r__(r, reason) if not prune:
return False
return True return True
variable = r
if variable.owner:
apply_node = variable.owner
used_or_output = [output for output in apply_node.outputs
if output.clients or output in self.outputs]
# If the apply node is not used and is not an output
if not used_or_output:
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)],
reason=reason)
# variable should not have any clients.
# assert not variable.clients
# variable should be in self.variables
# Why this assert fail? Making it True could cause opt speed up
# I think this is caused as we remove var in self.variables in
# another place.
# assert variable in self.variables
if variable in self.variables:
# If the owner have other outputs still used,
# then we must keep that variable in the graph.
if not variable.owner or not any(
[var for var in variable.owner.outputs
if var.clients]):
self.variables.remove(variable)
# This allow to quickly know if a var is still in the fgraph
# or not.
del variable.fgraph
return False return False
# import # # import #
...@@ -443,65 +491,6 @@ class FunctionGraph(utils.object2): ...@@ -443,65 +491,6 @@ class FunctionGraph(utils.object2):
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node, reason) self.execute_callbacks('on_import', node, reason)
# prune #
def __prune_r__(self, variable, reason=None):
"""
Should be called for variable that aren't used anymore:
len(var.clients) == 0.
This do not mean we will remove it from fgraph.variables. If
the owner stay in the fgraph as other outputs are still used,
the variable will stay in fgraph.variables.
"""
# Prunes the owners of the variables.
if variable.owner:
self.__prune__(variable.owner, reason)
# variable should not have any clients.
# assert not variable.clients
# variable should be in self.variables
# Why this assert fail? Making it True could cause opt speed up
# I think this is caused as we remove var in self.variables in
# another place.
# assert variable in self.variables
if variable in self.variables:
# If the owner have other outputs still used,
# then we must keep that variable in the graph.
if not variable.owner or not any(
[var for var in variable.owner.outputs
if var.clients]):
self.variables.remove(variable)
# This allow to quickly know if a var is still in the fgraph
# or not.
del variable.fgraph
def __prune__(self, apply_node, reason=None):
"""
Always called on owner of pruned variable from the graph.
This do not mean we will remove it from the graph. If other
outputs are still used, we will keep the node in the graph.
"""
# If apply_node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in apply_node.outputs:
# Cannot prune an op which is an output or used somewhere
if output.clients or output in self.outputs:
return
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
for i, input in enumerate(apply_node.inputs):
self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs)
# change input # # change input #
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
""" """
...@@ -546,9 +535,8 @@ class FunctionGraph(utils.object2): ...@@ -546,9 +535,8 @@ class FunctionGraph(utils.object2):
# transaction will be reverted later. # transaction will be reverted later.
self.execute_callbacks('on_change_input', node, i, self.execute_callbacks('on_change_input', node, i,
r, new_r, reason=reason) r, new_r, reason=reason)
if prune: if prune:
self.__prune_r__(r, reason=reason) self.__remove_clients__(r, [], True)
# replace # # replace #
def replace(self, r, new_r, reason=None, verbose=None): def replace(self, r, new_r, reason=None, verbose=None):
......
...@@ -375,8 +375,8 @@ class _tensor_py_operators(object): ...@@ -375,8 +375,8 @@ class _tensor_py_operators(object):
If `target` is `'cpu'` this will transfer to a TensorType (if If `target` is `'cpu'` this will transfer to a TensorType (if
not already one). Other types may define additional targets. not already one). Other types may define additional targets.
Paramters Parameters
--------- ----------
target : str target : str
The desired location of the output variable The desired location of the output variable
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论