提交 415ae91a authored 作者: Frederic's avatar Frederic

Interface change: add param reason to env_reafure.on_import().

上级 5df9e21d
......@@ -1428,6 +1428,8 @@ class _VariableEquivalenceTracker(object):
self.reasons = {}
self.replaced_by = {}
self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
assert fgraph is self.fgraph
......@@ -1442,8 +1444,9 @@ class _VariableEquivalenceTracker(object):
self.active_nodes.remove(node)
self.inactive_nodes.add(node)
def on_import(self, fgraph, node):
self.event_list.append(_FunctionGraphEvent('import', node))
def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason))
#print 'NEW NODE', node, id(node)
assert node not in self.active_nodes
......
......@@ -380,7 +380,7 @@ if 0:
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_import(self, fgraph, app):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
......@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_import(self, fgraph, app):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import")
......
......@@ -107,7 +107,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input)
self.variables.add(input)
self.__import_r__(outputs)
self.__import_r__(outputs, reason="init")
for i, output in enumerate(outputs):
output.clients.append(('output', i))
......@@ -217,7 +217,7 @@ class FunctionGraph(utils.object2):
return False
### import ###
def __import_r__(self, variables):
def __import_r__(self, variables, reason):
global NullType
if NullType is None:
from null_type import NullType
......@@ -226,7 +226,7 @@ class FunctionGraph(utils.object2):
for apply_node in [r.owner for r in variables if r.owner is not None]:
if apply_node not in r_owner_done:
r_owner_done.add(apply_node)
self.__import__(apply_node)
self.__import__(apply_node, reason=reason)
for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type, NullType):
......@@ -237,7 +237,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(r)
self.variables.add(r)
def __import__(self, apply_node, check=True):
def __import__(self, apply_node, check=True, reason=None):
node = apply_node
# We import the nodes in topological order. We only are interested
......@@ -335,7 +335,7 @@ class FunctionGraph(utils.object2):
self.variables.add(input)
self.__add_clients__(input, [(node, i)])
assert node.fgraph is self
self.execute_callbacks('on_import', node)
self.execute_callbacks('on_import', node, reason)
### prune ###
def __prune_r__(self, variables, reason=None):
......@@ -400,7 +400,7 @@ class FunctionGraph(utils.object2):
if r is new_r:
return
self.__import_r__([new_r])
self.__import_r__([new_r], reason=reason)
self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid
......
......@@ -421,7 +421,7 @@ class MergeFeature(object):
self.blacklist = []
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct
......@@ -433,7 +433,7 @@ class MergeFeature(object):
if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r)
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
for c in node.inputs:
if isinstance(c, graph.Constant):
self.process_constant(fgraph, c)
......@@ -1165,7 +1165,7 @@ class NavigatorOptimizer(Optimizer):
class Updater:
if importer is not None:
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
importer(node)
if pruner is not None:
def on_prune(self, fgraph, node, reason):
......@@ -1357,7 +1357,7 @@ class ChangeTracker:
def __init__(self):
self.changed = False
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
self.changed = True
def on_change_input(self, fgraph, node, i, r, new_r):
......
......@@ -55,7 +55,7 @@ class Feature(object):
functionality that it installed into the function_graph.
"""
def on_import(self, function_graph, node):
def on_import(self, function_graph, node, reason):
"""
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
......@@ -96,7 +96,7 @@ class Bookkeeper(Feature):
def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
......@@ -265,7 +265,7 @@ class NodeFinder(dict, Bookkeeper):
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
try:
self.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
......@@ -310,13 +310,13 @@ class PrintListener(Feature):
if self.active:
print "-- detaching from: ", fgraph
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if self.active:
print "-- importing: %s" % node
print "-- importing: %s, reason: %s" % (node, reason)
def on_prune(self, fgraph, node, reason):
if self.active:
print "-- pruning: %s" % node
print "-- pruning: %s, reason: %s" % (node, reason)
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
......
......@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.hints = {}
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.hints:
# this is a revert, not really an import
for r in node.outputs + node.inputs:
......
......@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner)
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = []
for o in outs:
......
......@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner)
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = []
for o in outs:
......
......@@ -980,9 +980,9 @@ class ShapeFeature(object):
# shape var -> graph v
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, reason='on_attach')
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.shape_of:
# this is a revert, not really an import
for r in node.outputs + node.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论