提交 5df9e21d authored 作者: Frederic's avatar Frederic

Interface change. env feature on_prune() fct now take a reason parameter

上级 9755603b
......@@ -1433,8 +1433,9 @@ class _VariableEquivalenceTracker(object):
assert fgraph is self.fgraph
self.fgraph = None
def on_prune(self, fgraph, node):
self.event_list.append(_FunctionGraphEvent('prune', node))
def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason))
#print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes
assert node not in self.inactive_nodes
......
......@@ -410,7 +410,7 @@ if 0:
self.stale_droot = True
def on_prune(self, fgraph, app):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app)
......@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self.stale_droot = True
def on_prune(self, fgraph, app):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import")
self.debug_all_apps.remove(app)
......
......@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise
"""
import sys
from theano.gof import graph
from theano.gof import utils
from theano.gof import toolbox
......@@ -193,7 +194,8 @@ class FunctionGraph(utils.object2):
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune=True):
def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None):
""" WRITEME
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
......@@ -209,7 +211,7 @@ class FunctionGraph(utils.object2):
assert entry not in r.clients # an op,i pair should be unique
if not r.clients:
if prune:
self.__prune_r__([r])
self.__prune_r__([r], reason)
return False
return True
return False
......@@ -336,15 +338,15 @@ class FunctionGraph(utils.object2):
self.execute_callbacks('on_import', node)
### prune ###
def __prune_r__(self, variables):
def __prune_r__(self, variables, reason=None):
# Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None):
self.__prune__(node)
self.__prune__(node, reason)
for r in variables:
if not r.clients and r in self.variables:
self.variables.remove(r)
def __prune__(self, apply_node):
def __prune__(self, apply_node, reason=None):
node = apply_node
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
......@@ -359,10 +361,10 @@ class FunctionGraph(utils.object2):
return
self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node)
self.execute_callbacks('on_prune', node, reason)
for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)])
self.__remove_clients__(input, [(node, i)], reason=reason)
#self.__prune_r__(node.inputs)
### change input ###
......@@ -408,8 +410,7 @@ class FunctionGraph(utils.object2):
r, new_r, reason=reason)
if prune:
self.__prune_r__([r])
self.__prune_r__([r], reason=reason)
### replace ###
......
......@@ -440,7 +440,7 @@ class MergeFeature(object):
self.process_node(fgraph, node)
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node)
for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
......@@ -1168,7 +1168,7 @@ class NavigatorOptimizer(Optimizer):
def on_import(self, fgraph, node):
importer(node)
if pruner is not None:
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
pruner(node)
if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r):
......
......@@ -64,7 +64,7 @@ class Feature(object):
you should do this by implementing on_attach.
"""
def on_prune(self, function_graph, node):
def on_prune(self, function_graph, node, reason):
"""
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
......@@ -100,7 +100,7 @@ class Bookkeeper(Feature):
def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node)
self.on_prune(fgraph, node, 'Bookkeeper.detach')
class History(Feature):
......@@ -278,7 +278,7 @@ class NodeFinder(dict, Bookkeeper):
print >> sys.stderr, 'OFFENDING node not hashable'
raise e
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
try:
nodes = self[node.op]
except TypeError: # node.op is unhashable
......@@ -314,7 +314,7 @@ class PrintListener(Feature):
if self.active:
print "-- importing: %s" % node
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
if self.active:
print "-- pruning: %s" % node
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论