提交 5df9e21d authored 作者: Frederic's avatar Frederic

Interface change. env feature on_prune() fct now take a reason parameter

上级 9755603b
...@@ -1433,8 +1433,9 @@ class _VariableEquivalenceTracker(object): ...@@ -1433,8 +1433,9 @@ class _VariableEquivalenceTracker(object):
assert fgraph is self.fgraph assert fgraph is self.fgraph
self.fgraph = None self.fgraph = None
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node)) self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason))
#print 'PRUNING NODE', node, id(node) #print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
......
...@@ -410,7 +410,7 @@ if 0: ...@@ -410,7 +410,7 @@ if 0:
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import") #if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app) #self.debug_all_apps.remove(app)
...@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import") if app not in self.debug_all_apps: raise ProtocolError("prune without import")
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
......
...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception ...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise types that it can raise
""" """
import sys import sys
from theano.gof import graph from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof import toolbox from theano.gof import toolbox
...@@ -193,7 +194,8 @@ class FunctionGraph(utils.object2): ...@@ -193,7 +194,8 @@ class FunctionGraph(utils.object2):
assert not set(r.clients).intersection(set(new_clients)) assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune=True): def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None):
""" WRITEME """ WRITEME
r -> variable r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore. clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...@@ -209,7 +211,7 @@ class FunctionGraph(utils.object2): ...@@ -209,7 +211,7 @@ class FunctionGraph(utils.object2):
assert entry not in r.clients # an op,i pair should be unique assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if not r.clients:
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r], reason)
return False return False
return True return True
return False return False
...@@ -336,15 +338,15 @@ class FunctionGraph(utils.object2): ...@@ -336,15 +338,15 @@ class FunctionGraph(utils.object2):
self.execute_callbacks('on_import', node) self.execute_callbacks('on_import', node)
### prune ### ### prune ###
def __prune_r__(self, variables): def __prune_r__(self, variables, reason=None):
# Prunes the owners of the variables. # Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None): for node in set(r.owner for r in variables if r.owner is not None):
self.__prune__(node) self.__prune__(node, reason)
for r in variables: for r in variables:
if not r.clients and r in self.variables: if not r.clients and r in self.variables:
self.variables.remove(r) self.variables.remove(r)
def __prune__(self, apply_node): def __prune__(self, apply_node, reason=None):
node = apply_node node = apply_node
if node not in self.apply_nodes: if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node) raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
...@@ -359,10 +361,10 @@ class FunctionGraph(utils.object2): ...@@ -359,10 +361,10 @@ class FunctionGraph(utils.object2):
return return
self.apply_nodes.remove(node) self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs) self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node) self.execute_callbacks('on_prune', node, reason)
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)]) self.__remove_clients__(input, [(node, i)], reason=reason)
#self.__prune_r__(node.inputs) #self.__prune_r__(node.inputs)
### change input ### ### change input ###
...@@ -408,8 +410,7 @@ class FunctionGraph(utils.object2): ...@@ -408,8 +410,7 @@ class FunctionGraph(utils.object2):
r, new_r, reason=reason) r, new_r, reason=reason)
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r], reason=reason)
### replace ### ### replace ###
......
...@@ -440,7 +440,7 @@ class MergeFeature(object): ...@@ -440,7 +440,7 @@ class MergeFeature(object):
self.process_node(fgraph, node) self.process_node(fgraph, node)
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1): if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
...@@ -1168,7 +1168,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1168,7 +1168,7 @@ class NavigatorOptimizer(Optimizer):
def on_import(self, fgraph, node): def on_import(self, fgraph, node):
importer(node) importer(node)
if pruner is not None: if pruner is not None:
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
pruner(node) pruner(node)
if chin is not None: if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
......
...@@ -64,7 +64,7 @@ class Feature(object): ...@@ -64,7 +64,7 @@ class Feature(object):
you should do this by implementing on_attach. you should do this by implementing on_attach.
""" """
def on_prune(self, function_graph, node): def on_prune(self, function_graph, node, reason):
""" """
Called whenever a node is pruned (removed) from the function_graph, Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph. after it is disconnected from the graph.
...@@ -100,7 +100,7 @@ class Bookkeeper(Feature): ...@@ -100,7 +100,7 @@ class Bookkeeper(Feature):
def on_detach(self, fgraph): def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node) self.on_prune(fgraph, node, 'Bookkeeper.detach')
class History(Feature): class History(Feature):
...@@ -278,7 +278,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -278,7 +278,7 @@ class NodeFinder(dict, Bookkeeper):
print >> sys.stderr, 'OFFENDING node not hashable' print >> sys.stderr, 'OFFENDING node not hashable'
raise e raise e
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
try: try:
nodes = self[node.op] nodes = self[node.op]
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
...@@ -314,7 +314,7 @@ class PrintListener(Feature): ...@@ -314,7 +314,7 @@ class PrintListener(Feature):
if self.active: if self.active:
print "-- importing: %s" % node print "-- importing: %s" % node
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
if self.active: if self.active:
print "-- pruning: %s" % node print "-- pruning: %s" % node
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论