提交 415ae91a authored 作者: Frederic's avatar Frederic

Interface change: add param reason to env_reafure.on_import().

上级 5df9e21d
...@@ -1428,6 +1428,8 @@ class _VariableEquivalenceTracker(object): ...@@ -1428,6 +1428,8 @@ class _VariableEquivalenceTracker(object):
self.reasons = {} self.reasons = {}
self.replaced_by = {} self.replaced_by = {}
self.event_list = [] self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
assert fgraph is self.fgraph assert fgraph is self.fgraph
...@@ -1442,8 +1444,9 @@ class _VariableEquivalenceTracker(object): ...@@ -1442,8 +1444,9 @@ class _VariableEquivalenceTracker(object):
self.active_nodes.remove(node) self.active_nodes.remove(node)
self.inactive_nodes.add(node) self.inactive_nodes.add(node)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('import', node)) self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason))
#print 'NEW NODE', node, id(node) #print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
......
...@@ -380,7 +380,7 @@ if 0: ...@@ -380,7 +380,7 @@ if 0:
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import") #if app in self.debug_all_apps: raise ProtocolError("double import")
...@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import") if app in self.debug_all_apps: raise ProtocolError("double import")
......
...@@ -107,7 +107,7 @@ class FunctionGraph(utils.object2): ...@@ -107,7 +107,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
self.__import_r__(outputs) self.__import_r__(outputs, reason="init")
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
output.clients.append(('output', i)) output.clients.append(('output', i))
...@@ -217,7 +217,7 @@ class FunctionGraph(utils.object2): ...@@ -217,7 +217,7 @@ class FunctionGraph(utils.object2):
return False return False
### import ### ### import ###
def __import_r__(self, variables): def __import_r__(self, variables, reason):
global NullType global NullType
if NullType is None: if NullType is None:
from null_type import NullType from null_type import NullType
...@@ -226,7 +226,7 @@ class FunctionGraph(utils.object2): ...@@ -226,7 +226,7 @@ class FunctionGraph(utils.object2):
for apply_node in [r.owner for r in variables if r.owner is not None]: for apply_node in [r.owner for r in variables if r.owner is not None]:
if apply_node not in r_owner_done: if apply_node not in r_owner_done:
r_owner_done.add(apply_node) r_owner_done.add(apply_node)
self.__import__(apply_node) self.__import__(apply_node, reason=reason)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type, NullType): if isinstance(r.type, NullType):
...@@ -237,7 +237,7 @@ class FunctionGraph(utils.object2): ...@@ -237,7 +237,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(r) self.__setup_r__(r)
self.variables.add(r) self.variables.add(r)
def __import__(self, apply_node, check=True): def __import__(self, apply_node, check=True, reason=None):
node = apply_node node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
...@@ -335,7 +335,7 @@ class FunctionGraph(utils.object2): ...@@ -335,7 +335,7 @@ class FunctionGraph(utils.object2):
self.variables.add(input) self.variables.add(input)
self.__add_clients__(input, [(node, i)]) self.__add_clients__(input, [(node, i)])
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node) self.execute_callbacks('on_import', node, reason)
### prune ### ### prune ###
def __prune_r__(self, variables, reason=None): def __prune_r__(self, variables, reason=None):
...@@ -400,7 +400,7 @@ class FunctionGraph(utils.object2): ...@@ -400,7 +400,7 @@ class FunctionGraph(utils.object2):
if r is new_r: if r is new_r:
return return
self.__import_r__([new_r]) self.__import_r__([new_r], reason=reason)
self.__add_clients__(new_r, [(node, i)]) self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid
......
...@@ -421,7 +421,7 @@ class MergeFeature(object): ...@@ -421,7 +421,7 @@ class MergeFeature(object):
self.blacklist = [] self.blacklist = []
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct # If inputs to node change, it is not guaranteed that it is distinct
...@@ -433,7 +433,7 @@ class MergeFeature(object): ...@@ -433,7 +433,7 @@ class MergeFeature(object):
if isinstance(new_r, graph.Constant): if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r) self.process_constant(fgraph, new_r)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant): if isinstance(c, graph.Constant):
self.process_constant(fgraph, c) self.process_constant(fgraph, c)
...@@ -1165,7 +1165,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1165,7 +1165,7 @@ class NavigatorOptimizer(Optimizer):
class Updater: class Updater:
if importer is not None: if importer is not None:
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
importer(node) importer(node)
if pruner is not None: if pruner is not None:
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
...@@ -1357,7 +1357,7 @@ class ChangeTracker: ...@@ -1357,7 +1357,7 @@ class ChangeTracker:
def __init__(self): def __init__(self):
self.changed = False self.changed = False
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
self.changed = True self.changed = True
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
......
...@@ -55,7 +55,7 @@ class Feature(object): ...@@ -55,7 +55,7 @@ class Feature(object):
functionality that it installed into the function_graph. functionality that it installed into the function_graph.
""" """
def on_import(self, function_graph, node): def on_import(self, function_graph, node, reason):
""" """
Called whenever a node is imported into function_graph, which is Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph. just before the node is actually connected to the graph.
...@@ -96,7 +96,7 @@ class Bookkeeper(Feature): ...@@ -96,7 +96,7 @@ class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
...@@ -265,7 +265,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -265,7 +265,7 @@ class NodeFinder(dict, Bookkeeper):
del fgraph.get_nodes del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph) Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
try: try:
self.setdefault(node.op, []).append(node) self.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
...@@ -310,13 +310,13 @@ class PrintListener(Feature): ...@@ -310,13 +310,13 @@ class PrintListener(Feature):
if self.active: if self.active:
print "-- detaching from: ", fgraph print "-- detaching from: ", fgraph
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if self.active: if self.active:
print "-- importing: %s" % node print "-- importing: %s, reason: %s" % (node, reason)
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
if self.active: if self.active:
print "-- pruning: %s" % node print "-- pruning: %s, reason: %s" % (node, reason)
def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active: if self.active:
......
...@@ -137,9 +137,9 @@ class HintsFeature(object): ...@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple) # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.hints = {} self.hints = {}
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.hints: if node.outputs[0] in self.hints:
# this is a revert, not really an import # this is a revert, not really an import
for r in node.outputs + node.inputs: for r in node.outputs + node.inputs:
......
...@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph # shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately # It will call infer_shape and set_shape appropriately
dummy_fgraph = None dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner) shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = [] ret = []
for o in outs: for o in outs:
......
...@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph # shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately # It will call infer_shape and set_shape appropriately
dummy_fgraph = None dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner) shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = [] ret = []
for o in outs: for o in outs:
......
...@@ -980,9 +980,9 @@ class ShapeFeature(object): ...@@ -980,9 +980,9 @@ class ShapeFeature(object):
# shape var -> graph v # shape var -> graph v
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, reason='on_attach')
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.shape_of: if node.outputs[0] in self.shape_of:
# this is a revert, not really an import # this is a revert, not really an import
for r in node.outputs + node.inputs: for r in node.outputs + node.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论