提交 e78e6cd7 authored 作者: nouiz's avatar nouiz

Merge pull request #921 from goodfeli/doc_topo

NEWS: documented, speed up(50% and renamed _dfs_toposort (now _contains_cycles) Rename FunctionGraph.nodes -> apply_nodes. The old nodes still work, but get warned if used. Made a parent class called Node for Apply and Variable class that are both FunctionGraph nodes
...@@ -62,7 +62,7 @@ purpose, you would set the ``view_map`` field as follows: ...@@ -62,7 +62,7 @@ purpose, you would set the ``view_map`` field as follows:
What this means is that the first output (position 0) is a view of the What this means is that the first output (position 0) is a view of the
first input (position 0). Even though the interface allows a list of first input (position 0). Even though the interface allows a list of
inputs that are a view of a given output, this feature is currently inputs that are viewed by a given output, this feature is currently
unsupported. Here are more examples: unsupported. Here are more examples:
......
...@@ -675,7 +675,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -675,7 +675,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
features=[equivalence_tracker]) features=[equivalence_tracker])
if not accept_inplace: if not accept_inplace:
for node in fgraph.nodes: for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
raise TypeError("Graph must not contain inplace operations", raise TypeError("Graph must not contain inplace operations",
node) node)
......
...@@ -131,7 +131,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -131,7 +131,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs) inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
fgraph = gof.fg.FunctionGraph(inputs, outputs) fgraph = gof.fg.FunctionGraph(inputs, outputs)
for node in fgraph.nodes: for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
if not accept_inplace: if not accept_inplace:
raise TypeError("Graph must not contain inplace operations", node, node.op) raise TypeError("Graph must not contain inplace operations", node, node.op)
......
...@@ -520,7 +520,7 @@ class ProfileMode(Mode): ...@@ -520,7 +520,7 @@ class ProfileMode(Mode):
print "Profile of Theano functions memory:" print "Profile of Theano functions memory:"
print "(This check only the output of each apply node. It don't check the temporary memory used by the op in the apply node.)" print "(This check only the output of each apply node. It don't check the temporary memory used by the op in the apply node.)"
nb_skipped = 0 nb_skipped = 0
for fgraph,nodes_mem in fct_memory.iteritems(): for fgraph, nodes_mem in fct_memory.iteritems():
size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()]) size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()])
if size_sum < min_memory_size: if size_sum < min_memory_size:
nb_skipped += 1 nb_skipped += 1
......
...@@ -711,7 +711,7 @@ if 0: # old code still to be ported from ProfileMode ...@@ -711,7 +711,7 @@ if 0: # old code still to be ported from ProfileMode
var_mem[out]=v var_mem[out]=v
print print
print "Profile of Theano functions memory:" print "Profile of Theano functions memory:"
for fgraph,nodes_mem in fct_memory.iteritems(): for fgraph, nodes_mem in fct_memory.iteritems():
print "Theano fct:", [fct for fct in fct_call.keys() if fct.maker.fgraph is fgraph][0].name print "Theano fct:", [fct for fct in fct_call.keys() if fct.maker.fgraph is fgraph][0].name
size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()]) size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()])
print " Max without gc, inplace and view (KB)",size_sum/1024 print " Max without gc, inplace and view (KB)",size_sum/1024
......
...@@ -79,10 +79,10 @@ class FunctionGraph(utils.object2): ...@@ -79,10 +79,10 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set. # so I probably am) this should be a set.
self._features = [] self._features = []
# All nodes in the subgraph defined by inputs and outputs are cached in nodes # All apply nodes in the subgraph defined by inputs and outputs are cached in this field
self.nodes = set() self.apply_nodes = set()
# Ditto for variables # Ditto for variable nodes
self.variables = set() self.variables = set()
self.inputs = list(inputs) self.inputs = list(inputs)
...@@ -151,13 +151,13 @@ class FunctionGraph(utils.object2): ...@@ -151,13 +151,13 @@ class FunctionGraph(utils.object2):
nodes and variables. If there are no features, this should set nodes and variables. If there are no features, this should set
them back to what they were originally. them back to what they were originally.
""" """
for node in self.nodes: for apply_node in self.apply_nodes:
del node.fgraph del apply_node.fgraph
del node.deps del apply_node.deps
for variable in self.variables: for variable in self.variables:
del variable.fgraph del variable.fgraph
del variable.clients del variable.clients
self.nodes = set() self.apply_nodes = set()
self.variables = set() self.variables = set()
self.inputs = None self.inputs = None
self.outputs = None self.outputs = None
...@@ -215,11 +215,11 @@ class FunctionGraph(utils.object2): ...@@ -215,11 +215,11 @@ class FunctionGraph(utils.object2):
if NullType is None: if NullType is None:
from null_type import NullType from null_type import NullType
# Imports the owners of the variables # Imports the owners of the variables
r_owner_done = set(self.nodes) r_owner_done = set(self.apply_nodes)
for node in [r.owner for r in variables if r.owner is not None]: for apply_node in [r.owner for r in variables if r.owner is not None]:
if node not in r_owner_done: if apply_node not in r_owner_done:
r_owner_done.add(node) r_owner_done.add(apply_node)
self.__import__(node) self.__import__(apply_node)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NullType): if isinstance(r.type,NullType):
...@@ -229,7 +229,9 @@ class FunctionGraph(utils.object2): ...@@ -229,7 +229,9 @@ class FunctionGraph(utils.object2):
self.__setup_r__(r) self.__setup_r__(r)
self.variables.add(r) self.variables.add(r)
def __import__(self, node, check = True): def __import__(self, apply_node, check = True):
node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set. # in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
...@@ -311,9 +313,9 @@ class FunctionGraph(utils.object2): ...@@ -311,9 +313,9 @@ class FunctionGraph(utils.object2):
r) r)
for node in new_nodes: for node in new_nodes:
assert node not in self.nodes assert node not in self.apply_nodes
self.__setup_node__(node) self.__setup_node__(node)
self.nodes.add(node) self.apply_nodes.add(node)
for output in node.outputs: for output in node.outputs:
self.__setup_r__(output) self.__setup_r__(output)
self.variables.add(output) self.variables.add(output)
...@@ -336,8 +338,9 @@ class FunctionGraph(utils.object2): ...@@ -336,8 +338,9 @@ class FunctionGraph(utils.object2):
if not r.clients and r in self.variables: if not r.clients and r in self.variables:
self.variables.remove(r) self.variables.remove(r)
def __prune__(self, node): def __prune__(self, apply_node):
if node not in self.nodes: node = apply_node
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node) raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
assert node.fgraph is self assert node.fgraph is self
# If node's outputs have no clients, removes it from the graph # If node's outputs have no clients, removes it from the graph
...@@ -348,7 +351,7 @@ class FunctionGraph(utils.object2): ...@@ -348,7 +351,7 @@ class FunctionGraph(utils.object2):
# Cannot prune an op which is an output or used somewhere # Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output): if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return return
self.nodes.remove(node) self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs) self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node) self.execute_callbacks('on_prune', node)
...@@ -446,21 +449,29 @@ class FunctionGraph(utils.object2): ...@@ -446,21 +449,29 @@ class FunctionGraph(utils.object2):
Adds a gof.toolbox.Feature to this function_graph Adds a gof.toolbox.Feature to this function_graph
and triggers its on_attach callback and triggers its on_attach callback
""" """
# Filter out literally identical features
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
#it would be nice if we could require a specific class instead of # Filter out functionally identical features.
#a "workalike" so we could do actual error checking # Features may use their on_attach method to raise
#if not isinstance(feature, toolbox.Feature): # toolbox.AlreadyThere if they detect that some
# raise TypeError("Expected gof.toolbox.Feature instance, got "+\ # installed feature does the same thing already
# str(type(feature)))
attach = getattr(feature, 'on_attach', None) attach = getattr(feature, 'on_attach', None)
if attach is not None: if attach is not None:
try: try:
attach(self) attach(self)
except toolbox.AlreadyThere: except toolbox.AlreadyThere:
return return
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
# raise TypeError("Expected gof.toolbox.Feature instance, got "+\
# str(type(feature)))
# Add the feature
self._features.append(feature) self._features.append(feature)
def remove_feature(self, feature): def remove_feature(self, feature):
...@@ -490,6 +501,9 @@ class FunctionGraph(utils.object2): ...@@ -490,6 +501,9 @@ class FunctionGraph(utils.object2):
try: try:
fn = getattr(feature, name) fn = getattr(feature, name)
except AttributeError: except AttributeError:
# this is safe because there is no work done inside the
# try; the AttributeError reall must come from feature.${name}
# not existing
continue continue
#####HORRIBLE OPTIONAL ARGUMENT HACK #####HORRIBLE OPTIONAL ARGUMENT HACK
...@@ -532,12 +546,12 @@ class FunctionGraph(utils.object2): ...@@ -532,12 +546,12 @@ class FunctionGraph(utils.object2):
{node: predecessors} where predecessors is a list of nodes {node: predecessors} where predecessors is a list of nodes
that should be computed before the key node. that should be computed before the key node.
""" """
if len(self.nodes) < 2: if len(self.apply_nodes) < 2:
# optimization # optimization
# when there are 0 or 1 nodes, no sorting is necessary # when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker produces # This special case happens a lot because the OpWiseCLinker produces
# 1-element graphs. # 1-element graphs.
return list(self.nodes) return list(self.apply_nodes)
fg = self fg = self
ords = self.orderings() ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords) order = graph.io_toposort(fg.inputs, fg.outputs, ords)
...@@ -569,26 +583,31 @@ class FunctionGraph(utils.object2): ...@@ -569,26 +583,31 @@ class FunctionGraph(utils.object2):
"""WRITEME Same as len(self.clients(r)).""" """WRITEME Same as len(self.clients(r))."""
return len(self.clients(r)) return len(self.clients(r))
# def edge(self, r): def nodes_getter(self):
# return r in self.inputs or r in self.orphans warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed 'apply_nodes'",
stacklevel=2)
return self.apply_nodes
def nodes_setter(self, value):
warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed 'apply_nodes'",
stacklevel=2)
self.apply_nodes = value
def nodes_deleter(self):
warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed 'apply_nodes'",
stacklevel=2)
del self.apply_nodes
# def follow(self, r): nodes = property(nodes_getter, nodes_setter, nodes_deleter)
# node = r.owner
# if self.edge(r):
# return None
# else:
# if node is None:
# raise Exception("what the fuck")
# return node.inputs
def check_integrity(self): def check_integrity(self):
"""WRITEME """WRITEME
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
""" """
nodes = graph.ops(self.inputs, self.outputs) nodes = graph.ops(self.inputs, self.outputs)
if self.nodes != nodes: if self.apply_nodes != nodes:
missing = nodes.difference(self.nodes) missing = nodes.difference(self.apply_nodes)
excess = self.nodes.difference(nodes) excess = self.apply_nodes.difference(nodes)
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess) raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess)
for node in nodes: for node in nodes:
if node.fgraph is not self: if node.fgraph is not self:
......
...@@ -21,7 +21,25 @@ is_same_graph_with_merge = None ...@@ -21,7 +21,25 @@ is_same_graph_with_merge = None
equal_computations = None equal_computations = None
class Apply(utils.object2): class Node(utils.object2):
"""A Node in a theano graph.
Graphs contain two kinds of Nodes--
Variable and Apply.
Edges in the graph are not explicitly represented.
Instead each Node keeps track of its parents via
Variable.owner / Apply.inputs and its children
via Variable.clients / Apply.outputs.
"""
def get_parents(self):
""" Return a list of the parents of this node.
Should return a copy--i.e., modifying the return
value should not modify the graph structure."""
raise NotImplementedError()
class Apply(Node):
""" """
An :term:`Apply` instance is a node in an expression graph which represents the application An :term:`Apply` instance is a node in an expression graph which represents the application
of an `Op` to some input `Variable` nodes, producing some output `Variable` nodes. of an `Op` to some input `Variable` nodes, producing some output `Variable` nodes.
...@@ -202,6 +220,9 @@ class Apply(utils.object2): ...@@ -202,6 +220,9 @@ class Apply(utils.object2):
new_node.inputs = new_inputs new_node.inputs = new_inputs
return new_node return new_node
def get_parents(self):
return list( self.inputs )
#convenience properties #convenience properties
nin = property(lambda self: len(self.inputs), doc='same as len(self.inputs)') nin = property(lambda self: len(self.inputs), doc='same as len(self.inputs)')
"""property: Number of inputs""" """property: Number of inputs"""
...@@ -210,7 +231,7 @@ class Apply(utils.object2): ...@@ -210,7 +231,7 @@ class Apply(utils.object2):
"""property: Number of outputs""" """property: Number of outputs"""
class Variable(utils.object2): class Variable(Node):
""" """
A :term:`Variable` is a node in an expression graph that represents a variable. A :term:`Variable` is a node in an expression graph that represents a variable.
...@@ -364,6 +385,11 @@ class Variable(utils.object2): ...@@ -364,6 +385,11 @@ class Variable(utils.object2):
raise NotImplementedError('Subclasses of Variable must provide __ge__', raise NotImplementedError('Subclasses of Variable must provide __ge__',
self.__class__.__name__) self.__class__.__name__)
def get_parents(self):
if self.owner is not None:
return [ self.owner ]
return [ ]
def env_getter(self): def env_getter(self):
warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'", warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'",
stacklevel=2) stacklevel=2)
...@@ -726,13 +752,26 @@ def general_toposort(r_out, deps, debug_print=False): ...@@ -726,13 +752,26 @@ def general_toposort(r_out, deps, debug_print=False):
return rlist return rlist
def io_toposort(i, o, orderings=None): def io_toposort(inputs, outputs, orderings=None):
"""WRITEME """WRITEME
inputs: a list or tuple of Variable instances
outputs: a list or tuple of Variable instances
orderings: a dictionary
key: Apply instance
value: list of Apply instance
it is important that the value be
a container with a deterministic iteration
order. no sets allowed!
""" """
if orderings is None: if orderings is None:
orderings = {} orderings = {}
#the inputs are used only here in the function that decides what 'predecessors' to explore #the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(i) iset = set(inputs)
def deps(obj): def deps(obj):
rval = [] rval = []
...@@ -747,7 +786,7 @@ def io_toposort(i, o, orderings=None): ...@@ -747,7 +786,7 @@ def io_toposort(i, o, orderings=None):
assert not orderings.get(obj, []) assert not orderings.get(obj, [])
return rval return rval
topo = general_toposort(o, deps) topo = general_toposort(outputs, deps)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
...@@ -162,7 +162,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -162,7 +162,7 @@ class SeqOptimizer(Optimizer, list):
l = [] l = []
if fgraph.profile: if fgraph.profile:
validate_before = fgraph.profile.validate_time validate_before = fgraph.profile.validate_time
nb_node_before = len(fgraph.nodes) nb_node_before = len(fgraph.apply_nodes)
sub_profs = [] sub_profs = []
for optimizer in self: for optimizer in self:
try: try:
...@@ -184,7 +184,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -184,7 +184,7 @@ class SeqOptimizer(Optimizer, list):
print "SeqOptimizer", print "SeqOptimizer",
if hasattr(self,"name"): print self.name, if hasattr(self,"name"): print self.name,
elif hasattr(self,"__name__"): print self.__name__, elif hasattr(self,"__name__"): print self.__name__,
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(fgraph.nodes)) print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(fgraph.apply_nodes))
print " time %.3fs for validate " % ( print " time %.3fs for validate " % (
fgraph.profile.validate_time - validate_before) fgraph.profile.validate_time - validate_before)
ll=[] ll=[]
...@@ -208,7 +208,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -208,7 +208,7 @@ class SeqOptimizer(Optimizer, list):
else: else:
validate_time = None validate_time = None
return (self, l, validate_time, nb_node_before, return (self, l, validate_time, nb_node_before,
len(fgraph.nodes), sub_profs) len(fgraph.apply_nodes), sub_profs)
def __eq__(self, other): def __eq__(self, other):
#added to override the list's __eq__ implementation #added to override the list's __eq__ implementation
...@@ -1503,7 +1503,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1503,7 +1503,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = True max_use_abort = True
opt_name = (getattr(lopt, "name", None) opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", "")) or getattr(lopt, "__name__", ""))
if node not in fgraph.nodes: if node not in fgraph.apply_nodes:
# go to next node # go to next node
break break
finally: finally:
......
...@@ -71,9 +71,9 @@ if 0: ...@@ -71,9 +71,9 @@ if 0:
def apply(self, fgraph): def apply(self, fgraph):
tasks = defaultdict(list) tasks = defaultdict(list)
if self.max_use_ratio is not None: if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(fgraph.nodes) max_uses = self.max_use_ratio * len(fgraph.apply_nodes)
runs = defaultdict(int) runs = defaultdict(int)
else: else:
runs = None runs = None
...@@ -91,10 +91,10 @@ if 0: ...@@ -91,10 +91,10 @@ if 0:
self.backtrack(new_r.owner, tasks) self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == # # # == NOT IDEAL == #
# for node in fgraph.nodes: # for node in fgraph.apply_nodes:
# importer(node) # importer(node)
for node in fgraph.toposort(): for node in fgraph.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op)) tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
...@@ -124,7 +124,7 @@ if 0: ...@@ -124,7 +124,7 @@ if 0:
# if isinstance(in1, basestring): # if isinstance(in1, basestring):
# candidate.match[in1] = in2 # candidate.match[in1] = in2
# for client in node.clients: # for client in node.clients:
# op = node.op # op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)]) # patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
......
...@@ -7,6 +7,9 @@ import graph ...@@ -7,6 +7,9 @@ import graph
class AlreadyThere(Exception): class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
feature."""
pass pass
...@@ -32,13 +35,18 @@ class Feature(object): ...@@ -32,13 +35,18 @@ class Feature(object):
def on_attach(self, function_graph): def on_attach(self, function_graph):
""" """
Called by FunctionGraph.attach_feature, the method that attaches the feature Called by FunctionGraph.attach_feature, the method that attaches
to the FunctionGraph. Since this is called after the FunctionGraph the feature to the FunctionGraph. Since this is called after the
is initially populated, this is where you should run checks on the FunctionGraph is initially populated, this is where you should
initial contents of the FunctionGraph. run checks on the initial contents of the FunctionGraph.
The feature has great freedom in what
it can do with the function_graph: it may, for example, add methods The on_attach method may raise the AlreadyThere exception to cancel
to it dynamically. the attach operation if it detects that another Feature instance
implementing the same functionality is already atttached to the
FunctionGraph.
The feature has great freedom in what it can do with the
function_graph: it may, for example, add methods to it dynamically.
""" """
def on_detach(self, function_graph): def on_detach(self, function_graph):
...@@ -219,7 +227,7 @@ class ReplaceValidate(History, Validator): ...@@ -219,7 +227,7 @@ class ReplaceValidate(History, Validator):
""" """
chk = fgraph.replace_all_validate(replacements, reason) chk = fgraph.replace_all_validate(replacements, reason)
for rm in remove: for rm in remove:
if rm in fgraph.nodes or rm in fgraph.variables: if rm in fgraph.apply_nodes or rm in fgraph.variables:
fgraph.revert(chk) fgraph.revert(chk)
if warn: if warn:
out = sys.stderr out = sys.stderr
......
...@@ -1002,7 +1002,7 @@ def test_many_arg_elemwise(): ...@@ -1002,7 +1002,7 @@ def test_many_arg_elemwise():
#assert that the test was done on the gpu. #assert that the test was done on the gpu.
if mode is mode_with_gpu: if mode is mode_with_gpu:
assert any([isinstance(node.op, cuda.GpuElemwise) assert any([isinstance(node.op, cuda.GpuElemwise)
for node in f.maker.fgraph.nodes]) for node in f.maker.fgraph.apply_nodes])
#test the optijmization local_gpu_elemwise_1 #test the optijmization local_gpu_elemwise_1
f = theano.function( f = theano.function(
...@@ -1013,7 +1013,7 @@ def test_many_arg_elemwise(): ...@@ -1013,7 +1013,7 @@ def test_many_arg_elemwise():
#assert that the test was done on the gpu. #assert that the test was done on the gpu.
if mode is mode_with_gpu: if mode is mode_with_gpu:
assert any([isinstance(node.op, cuda.GpuElemwise) assert any([isinstance(node.op, cuda.GpuElemwise)
for node in f.maker.fgraph.nodes]) for node in f.maker.fgraph.apply_nodes])
assert numpy.allclose(out, outputs[-1]) assert numpy.allclose(out, outputs[-1])
results_gpu, results_cpu = outputs results_gpu, results_cpu = outputs
......
...@@ -2667,7 +2667,7 @@ class Composite(ScalarOp): ...@@ -2667,7 +2667,7 @@ class Composite(ScalarOp):
def init_fgraph(self): def init_fgraph(self):
fgraph = FunctionGraph(*gof.graph.clone(self.inputs, self.outputs)) fgraph = FunctionGraph(*gof.graph.clone(self.inputs, self.outputs))
gof.MergeOptimizer().optimize(fgraph) gof.MergeOptimizer().optimize(fgraph)
for node in fgraph.nodes: for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
raise ValueError("The fgraph to Composite must be exclusively" raise ValueError("The fgraph to Composite must be exclusively"
" composed of ScalarOp instances.") " composed of ScalarOp instances.")
......
...@@ -1382,7 +1382,7 @@ class GemmOptimizer(Optimizer): ...@@ -1382,7 +1382,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub, (theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))): theano.scalar.Neg, theano.scalar.Mul))):
continue continue
if not node in fgraph.nodes: if not node in fgraph.apply_nodes:
# This mean that we already removed this node from # This mean that we already removed this node from
# the graph # the graph
continue continue
......
...@@ -176,7 +176,7 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -176,7 +176,7 @@ def inplace_elemwise_optimizer_op(OP):
# We execute `validate` after this number of change. # We execute `validate` after this number of change.
check_each_change = config.tensor.insert_inplace_optimizer_validate_nb check_each_change = config.tensor.insert_inplace_optimizer_validate_nb
if check_each_change == -1: if check_each_change == -1:
if len(fgraph.nodes) > 500: if len(fgraph.apply_nodes) > 500:
check_each_change = 10 check_each_change = 10
else: else:
check_each_change = 1 check_each_change = 1
...@@ -4596,7 +4596,7 @@ class FusionOptimizer(Optimizer): ...@@ -4596,7 +4596,7 @@ class FusionOptimizer(Optimizer):
did_something = False did_something = False
for node in nodelist: for node in nodelist:
# Don't try to fuse node that have already been fused. # Don't try to fuse node that have already been fused.
if node in fgraph.nodes: if node in fgraph.apply_nodes:
new_outputs = self.optimizer(node) new_outputs = self.optimizer(node)
if new_outputs: if new_outputs:
assert len(new_outputs) == len(node.outputs) assert len(new_outputs) == len(node.outputs)
......
...@@ -478,7 +478,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], ...@@ -478,7 +478,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
mode='FAST_RUN', mode='FAST_RUN',
on_unused_input='ignore') on_unused_input='ignore')
nb_gemm = 0 nb_gemm = 0
for node in f.maker.fgraph.nodes: for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot: if node.op == T.dot:
raise Failure('dot not changed to gemm_inplace in graph') raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22: if node.op == _dot22:
...@@ -488,7 +488,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], ...@@ -488,7 +488,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm) assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None), g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore') allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.fgraph.nodes: for node in g.maker.fgraph.apply_nodes:
if node.op == gemm_inplace: if node.op == gemm_inplace:
raise Exception('gemm_inplace in original graph') raise Exception('gemm_inplace in original graph')
...@@ -561,14 +561,14 @@ def test_gemm_opt_double_gemm(): ...@@ -561,14 +561,14 @@ def test_gemm_opt_double_gemm():
try: try:
f = inplace_func([Param(ii, mutable=True) for ii in i], o, f = inplace_func([Param(ii, mutable=True) for ii in i], o,
mode='FAST_RUN', on_unused_input='ignore') mode='FAST_RUN', on_unused_input='ignore')
for node in f.maker.fgraph.nodes: for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot: if node.op == T.dot:
raise Failure('dot in graph') raise Failure('dot in graph')
if node.op == _dot22: if node.op == _dot22:
raise Failure('_dot22 in graph') raise Failure('_dot22 in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None), g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
on_unused_input='ignore') on_unused_input='ignore')
#for node in g.maker.fgraph.nodes: #for node in g.maker.fgraph.apply_nodes:
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph') # if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
...@@ -760,11 +760,11 @@ def test_gemm_opt_vector_stuff(): ...@@ -760,11 +760,11 @@ def test_gemm_opt_vector_stuff():
u, v = T.vector(), T.vector() u, v = T.vector(), T.vector()
f = inplace_func([a, u, v], a + T.dot(u, v), mode='FAST_RUN') f = inplace_func([a, u, v], a + T.dot(u, v), mode='FAST_RUN')
if gemm_inplace in [n.op for n in f.maker.fgraph.nodes]: if gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]:
raise Failure('gemm_inplace in graph') raise Failure('gemm_inplace in graph')
f = inplace_func([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN') f = inplace_func([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.fgraph.nodes]): if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
raise Failure('gemm_inplace in graph') raise Failure('gemm_inplace in graph')
...@@ -823,16 +823,16 @@ def test_inplace0(): ...@@ -823,16 +823,16 @@ def test_inplace0():
f = inplace_func([Z, b, R, S], f = inplace_func([Z, b, R, S],
[Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN') [Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.fgraph.nodes]): if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
print pp(f.maker.fgraph.outputs[0]) print pp(f.maker.fgraph.outputs[0])
raise Failure('gemm_inplace in graph') raise Failure('gemm_inplace in graph')
assert gemm_no_inplace in [n.op for n in f.maker.fgraph.nodes] assert gemm_no_inplace in [n.op for n in f.maker.fgraph.apply_nodes]
# gemm_inplace should be inserted here, to work in-place on Z*c # gemm_inplace should be inserted here, to work in-place on Z*c
f = inplace_func([X, Y, Z, a, b, R, S, c], f = inplace_func([X, Y, Z, a, b, R, S, c],
[Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)], [Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)],
mode='FAST_RUN') mode='FAST_RUN')
if (not gemm_inplace in [n.op for n in f.maker.fgraph.nodes]): if (not gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
theano.printing.debugprint(f) theano.printing.debugprint(f)
raise Failure('no gemm_inplace in graph') raise Failure('no gemm_inplace in graph')
...@@ -844,7 +844,7 @@ def test_inplace1(): ...@@ -844,7 +844,7 @@ def test_inplace1():
[Z + Z + T.dot(X, Y)], mode='FAST_RUN') [Z + Z + T.dot(X, Y)], mode='FAST_RUN')
#theano.printing.debugprint(f) #theano.printing.debugprint(f)
# it doesn't work inplace because we didn't mark Z as mutable input # it doesn't work inplace because we didn't mark Z as mutable input
assert [n.op for n in f.maker.fgraph.nodes] == [gemm_no_inplace] assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
def test_dot22(): def test_dot22():
......
...@@ -590,7 +590,7 @@ def test_naacl_model(iters_per_unsup=3, iters_per_sup=3, ...@@ -590,7 +590,7 @@ def test_naacl_model(iters_per_unsup=3, iters_per_sup=3,
#print input_pretraining_gradients[4].owner.inputs[1].owner.inputs #print input_pretraining_gradients[4].owner.inputs[1].owner.inputs
#sys.exit() #sys.exit()
#print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.fgraph.nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str)) #print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.fgraph.apply_nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
rng = N.random.RandomState(unittest_tools.fetch_seed(23904)) rng = N.random.RandomState(unittest_tools.fetch_seed(23904))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论