提交 a351e3db authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1482 from nouiz/rnade

Scan crash fix
...@@ -482,6 +482,14 @@ import theano and print the config variable, as in: ...@@ -482,6 +482,14 @@ import theano and print the config variable, as in:
This flag's value cannot be modified during the program execution. This flag's value cannot be modified during the program execution.
.. attribute:: optimizer_verbose
Bool value: either True or False
Default: False
When True, we print on the stdout the optimization applied.
.. attribute:: nocleanup .. attribute:: nocleanup
Bool value: either True or False Bool value: either True or False
...@@ -630,6 +638,12 @@ import theano and print the config variable, as in: ...@@ -630,6 +638,12 @@ import theano and print the config variable, as in:
this Op this Op
- ``'raise'`` will raise an Exception - ``'raise'`` will raise an Exception
.. attribute:: config.compute_test_value_opt
As ``compute_test_value``, but it is the value used during Theano
optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization.
.. attribute:: config.exception_verbosity .. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``. String Value: ``'low'``, ``'high'``.
......
...@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object): ...@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object):
self.reasons = {} self.reasons = {}
self.replaced_by = {} self.replaced_by = {}
self.event_list = [] self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
assert fgraph is self.fgraph assert fgraph is self.fgraph
self.fgraph = None self.fgraph = None
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node)) self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason))
#print 'PRUNING NODE', node, id(node) #print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
self.active_nodes.remove(node) self.active_nodes.remove(node)
self.inactive_nodes.add(node) self.inactive_nodes.add(node)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('import', node)) self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason))
#print 'NEW NODE', node, id(node) #print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
...@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph # optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value compute_test_value_orig = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = "off" theano.config.compute_test_value = theano.config.compute_test_value_opt
optimizer(fgraph) optimizer(fgraph)
theano.compile.function_module.insert_deepcopy(fgraph, inputs, theano.compile.function_module.insert_deepcopy(fgraph, inputs,
......
...@@ -1018,7 +1018,7 @@ class FunctionMaker(object): ...@@ -1018,7 +1018,7 @@ class FunctionMaker(object):
compute_test_value_orig = theano.config.compute_test_value compute_test_value_orig = theano.config.compute_test_value
add_stack_trace_on_call = gof.Op.add_stack_trace_on_call add_stack_trace_on_call = gof.Op.add_stack_trace_on_call
try: try:
theano.config.compute_test_value = "off" theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False gof.Op.add_stack_trace_on_call = False
start_optimizer = time.time() start_optimizer = time.time()
optimizer_profile = optimizer(fgraph) optimizer_profile = optimizer(fgraph)
......
...@@ -157,6 +157,11 @@ AddConfigVar('optimizer', ...@@ -157,6 +157,11 @@ AddConfigVar('optimizer',
EnumStr('fast_run', 'merge', 'fast_compile', 'None'), EnumStr('fast_run', 'merge', 'fast_compile', 'None'),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_verbose',
"If True, we print all optimization being applied",
BoolParam(False),
in_c_key=False)
AddConfigVar('on_opt_error', AddConfigVar('on_opt_error',
("What to do when an optimization crashes: warn and skip it, raise " ("What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."), "the exception, or fall into the pdb debugger."),
...@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value', ...@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value',
"Constants, SharedVariables and the tag 'test_value' as inputs " "Constants, SharedVariables and the tag 'test_value' as inputs "
"to the function. This helps the user track down problems in the " "to the function. This helps the user track down problems in the "
"graph before it gets optimized."), "graph before it gets optimized."),
EnumStr('off', 'ignore', 'warn', 'raise'), EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False) in_c_key=False)
AddConfigVar('compute_test_value_opt',
("For debugging Theano optimization only."
" Same as compute_test_value, but is used"
" during Theano optimization"),
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False)
"""Note to developers: """Note to developers:
Generally your exceptions should use an apply node's __str__ Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity method when exception_verbosity == 'low'. When exception_verbosity
......
...@@ -380,7 +380,7 @@ if 0: ...@@ -380,7 +380,7 @@ if 0:
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import") #if app in self.debug_all_apps: raise ProtocolError("double import")
...@@ -410,7 +410,7 @@ if 0: ...@@ -410,7 +410,7 @@ if 0:
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import") #if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app) #self.debug_all_apps.remove(app)
...@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import") if app in self.debug_all_apps: raise ProtocolError("double import")
...@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import") if app not in self.debug_all_apps: raise ProtocolError("prune without import")
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
......
...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception ...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise types that it can raise
""" """
import sys import sys
from theano.gof import graph from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof import toolbox from theano.gof import toolbox
...@@ -16,6 +17,7 @@ NullType = None ...@@ -16,6 +17,7 @@ NullType = None
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to FunctionGraph when the This exception should be thrown by listeners to FunctionGraph when the
...@@ -82,7 +84,8 @@ class FunctionGraph(utils.object2): ...@@ -82,7 +84,8 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set. # so I probably am) this should be a set.
self._features = [] self._features = []
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field # All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self.apply_nodes = set() self.apply_nodes = set()
# Ditto for variable nodes # Ditto for variable nodes
...@@ -104,7 +107,7 @@ class FunctionGraph(utils.object2): ...@@ -104,7 +107,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
self.__import_r__(outputs) self.__import_r__(outputs, reason="init")
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
output.clients.append(('output', i)) output.clients.append(('output', i))
...@@ -112,12 +115,12 @@ class FunctionGraph(utils.object2): ...@@ -112,12 +115,12 @@ class FunctionGraph(utils.object2):
self.variable_locks = {} self.variable_locks = {}
self.profile = None self.profile = None
### Setup a Variable ### ### Setup a Variable ###
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph # sets up r so it belongs to this fgraph
if hasattr(r, 'fgraph') and r.fgraph is not None and r.fgraph is not self: if (hasattr(r, 'fgraph') and
r.fgraph is not None and
r.fgraph is not self):
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self r.fgraph = self
r.clients = [] r.clients = []
...@@ -165,13 +168,13 @@ class FunctionGraph(utils.object2): ...@@ -165,13 +168,13 @@ class FunctionGraph(utils.object2):
self.inputs = None self.inputs = None
self.outputs = None self.outputs = None
### clients ### ### clients ###
def clients(self, r): def clients(self, r):
""" """
Set of all the (node, i) pairs such that node.inputs[i] is r. Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have r as input at index i. Tell differently, a list of (node,i) such that each node have
r as input at index i.
""" """
return r.clients return r.clients
...@@ -184,12 +187,15 @@ class FunctionGraph(utils.object2): ...@@ -184,12 +187,15 @@ class FunctionGraph(utils.object2):
""" """
if set(r.clients).intersection(set(new_clients)): if set(r.clients).intersection(set(new_clients)):
print >> sys.stderr, 'ERROR: clients intersect!' print >> sys.stderr, 'ERROR: clients intersect!'
print >> sys.stderr, ' RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients] print >> sys.stderr, ' RCLIENTS of', r, [(n, i, type(n), id(n))
print >> sys.stderr, ' NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients] for n, i in r.clients]
print >> sys.stderr, ' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients]
assert not set(r.clients).intersection(set(new_clients)) assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune = True): def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None):
""" WRITEME """ WRITEME
r -> variable r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore. clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...@@ -205,15 +211,13 @@ class FunctionGraph(utils.object2): ...@@ -205,15 +211,13 @@ class FunctionGraph(utils.object2):
assert entry not in r.clients # an op,i pair should be unique assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if not r.clients:
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r], reason)
return False return False
return True return True
return False return False
### import ### ### import ###
def __import_r__(self, variables, reason):
def __import_r__(self, variables):
global NullType global NullType
if NullType is None: if NullType is None:
from null_type import NullType from null_type import NullType
...@@ -222,17 +226,18 @@ class FunctionGraph(utils.object2): ...@@ -222,17 +226,18 @@ class FunctionGraph(utils.object2):
for apply_node in [r.owner for r in variables if r.owner is not None]: for apply_node in [r.owner for r in variables if r.owner is not None]:
if apply_node not in r_owner_done: if apply_node not in r_owner_done:
r_owner_done.add(apply_node) r_owner_done.add(apply_node)
self.__import__(apply_node) self.__import__(apply_node, reason=reason)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NullType): if isinstance(r.type, NullType):
raise TypeError("Computation graph contains a NaN. "+r.type.why_null) raise TypeError("Computation graph contains a NaN. " +
r.type.why_null)
raise MissingInputError("Undeclared input", r) raise MissingInputError("Undeclared input", r)
if not getattr(r, 'fgraph', None) is self: if not getattr(r, 'fgraph', None) is self:
self.__setup_r__(r) self.__setup_r__(r)
self.variables.add(r) self.variables.add(r)
def __import__(self, apply_node, check = True): def __import__(self, apply_node, check=True, reason=None):
node = apply_node node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
...@@ -248,7 +253,9 @@ class FunctionGraph(utils.object2): ...@@ -248,7 +253,9 @@ class FunctionGraph(utils.object2):
for r in node.inputs: for r in node.inputs:
if hasattr(r, 'fgraph') and r.fgraph is not self: if hasattr(r, 'fgraph') and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if (r.owner is None and
not isinstance(r, graph.Constant) and
r not in self.inputs):
#Verbose error message #Verbose error message
#Show a complete chain of variables from the missing input to an output #Show a complete chain of variables from the missing input to an output
...@@ -328,20 +335,18 @@ class FunctionGraph(utils.object2): ...@@ -328,20 +335,18 @@ class FunctionGraph(utils.object2):
self.variables.add(input) self.variables.add(input)
self.__add_clients__(input, [(node, i)]) self.__add_clients__(input, [(node, i)])
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node) self.execute_callbacks('on_import', node, reason)
### prune ### ### prune ###
def __prune_r__(self, variables, reason=None):
def __prune_r__(self, variables):
# Prunes the owners of the variables. # Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None): for node in set(r.owner for r in variables if r.owner is not None):
self.__prune__(node) self.__prune__(node, reason)
for r in variables: for r in variables:
if not r.clients and r in self.variables: if not r.clients and r in self.variables:
self.variables.remove(r) self.variables.remove(r)
def __prune__(self, apply_node): def __prune__(self, apply_node, reason=None):
node = apply_node node = apply_node
if node not in self.apply_nodes: if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node) raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
...@@ -356,16 +361,13 @@ class FunctionGraph(utils.object2): ...@@ -356,16 +361,13 @@ class FunctionGraph(utils.object2):
return return
self.apply_nodes.remove(node) self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs) self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node) self.execute_callbacks('on_prune', node, reason)
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)]) self.__remove_clients__(input, [(node, i)], reason=reason)
#self.__prune_r__(node.inputs) #self.__prune_r__(node.inputs)
### change input ### ### change input ###
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
"""WRITEME """WRITEME
Changes node.inputs[i] to new_r. Changes node.inputs[i] to new_r.
...@@ -398,25 +400,28 @@ class FunctionGraph(utils.object2): ...@@ -398,25 +400,28 @@ class FunctionGraph(utils.object2):
if r is new_r: if r is new_r:
return return
self.__import_r__([new_r]) self.__import_r__([new_r], reason=reason)
self.__add_clients__(new_r, [(node, i)]) self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the # However it may introduce cycles to the graph, in which case the
# transaction will be reverted later. # transaction will be reverted later.
self.execute_callbacks('on_change_input', node, i, r, new_r, reason=reason) self.execute_callbacks('on_change_input', node, i,
r, new_r, reason=reason)
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r], reason=reason)
### replace ### ### replace ###
def replace(self, r, new_r, reason=None, verbose=None):
def replace(self, r, new_r, reason=None):
""" WRITEME """ WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph. This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
""" """
if verbose is None:
verbose = config.optimizer_verbose
if verbose:
print reason, r, new_r
if r.fgraph is not self: if r.fgraph is not self:
raise Exception("Cannot replace %s because it does not belong to this FunctionGraph" % r, str(reason)) raise Exception("Cannot replace %s because it does not belong to this FunctionGraph" % r, str(reason))
if not r.type == new_r.type: if not r.type == new_r.type:
...@@ -440,8 +445,6 @@ class FunctionGraph(utils.object2): ...@@ -440,8 +445,6 @@ class FunctionGraph(utils.object2):
for r, new_r in pairs: for r, new_r in pairs:
self.replace(r, new_r, reason=reason) self.replace(r, new_r, reason=reason)
def extend(self, feature): def extend(self, feature):
warnings.warn("FunctionGraph.extend is deprecatd. It has been " warnings.warn("FunctionGraph.extend is deprecatd. It has been "
"renamed to FunctionGraph.attach_feature") "renamed to FunctionGraph.attach_feature")
...@@ -481,7 +484,9 @@ class FunctionGraph(utils.object2): ...@@ -481,7 +484,9 @@ class FunctionGraph(utils.object2):
"""WRITEME """WRITEME
Removes the feature from the graph. Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method is defined. Calls feature.on_detach(function_graph) if an on_detach method
is defined.
""" """
try: try:
self._features.remove(feature) self._features.remove(feature)
...@@ -491,9 +496,7 @@ class FunctionGraph(utils.object2): ...@@ -491,9 +496,7 @@ class FunctionGraph(utils.object2):
if detach is not None: if detach is not None:
detach(self) detach(self)
### callback utils ### ### callback utils ###
def execute_callbacks(self, name, *args, **kwargs): def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME """WRITEME
Calls Calls
...@@ -518,7 +521,6 @@ class FunctionGraph(utils.object2): ...@@ -518,7 +521,6 @@ class FunctionGraph(utils.object2):
else: else:
raise raise
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""WRITEME """WRITEME
Returns a dictionary d such that: Returns a dictionary d such that:
...@@ -534,9 +536,7 @@ class FunctionGraph(utils.object2): ...@@ -534,9 +536,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args) d[feature] = fn(*args)
return d return d
### misc ### ### misc ###
def toposort(self): def toposort(self):
"""WRITEME """WRITEME
Returns an ordering of the graph's Apply nodes such that: Returns an ordering of the graph's Apply nodes such that:
...@@ -552,8 +552,8 @@ class FunctionGraph(utils.object2): ...@@ -552,8 +552,8 @@ class FunctionGraph(utils.object2):
if len(self.apply_nodes) < 2: if len(self.apply_nodes) < 2:
# optimization # optimization
# when there are 0 or 1 nodes, no sorting is necessary # when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker produces # This special case happens a lot because the OpWiseCLinker
# 1-element graphs. # produces 1-element graphs.
return list(self.apply_nodes) return list(self.apply_nodes)
fg = self fg = self
...@@ -568,8 +568,9 @@ class FunctionGraph(utils.object2): ...@@ -568,8 +568,9 @@ class FunctionGraph(utils.object2):
Return dict d s.t. d[node] is a list of nodes that must be evaluated Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated. before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all This is used primarily by the destroy_handler feature to ensure that
clients of any destroyed inputs have already computed their outputs. all clients of any destroyed inputs have already computed their
outputs.
:note: This only calls the orderings() fct on all features. It does not :note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself. take care of computing dependencies by itself.
...@@ -581,17 +582,19 @@ class FunctionGraph(utils.object2): ...@@ -581,17 +582,19 @@ class FunctionGraph(utils.object2):
if hasattr(feature, 'orderings'): if hasattr(feature, 'orderings'):
orderings = feature.orderings(self) orderings = feature.orderings(self)
if not isinstance(orderings, OrderedDict): if not isinstance(orderings, OrderedDict):
raise TypeError("Non-deterministic return value from " \ raise TypeError("Non-deterministic return value from " +
+str(feature.orderings) \ str(feature.orderings) +
+". Nondeterministic object is "+str(orderings)) ". Nondeterministic object is " +
str(orderings))
for node, prereqs in orderings.items(): for node, prereqs in orderings.items():
if not isinstance(prereqs, (list, OrderedSet)): if not isinstance(prereqs, (list, OrderedSet)):
raise TypeError("prereqs must be a type with a " raise TypeError(
"prereqs must be a type with a "
"deterministic iteration order, or toposort " "deterministic iteration order, or toposort "
" will be non-deterministic.") " will be non-deterministic.")
ords.setdefault(node, []).extend(prereqs) ords.setdefault(node, []).extend(prereqs)
# eliminate duplicate prereqs # eliminate duplicate prereqs
for (node,prereqs) in ords.items(): for (node, prereqs) in ords.items():
ords[node] = list(OrderedSet(prereqs)) ords[node] = list(OrderedSet(prereqs))
return ords return ords
...@@ -624,34 +627,48 @@ class FunctionGraph(utils.object2): ...@@ -624,34 +627,48 @@ class FunctionGraph(utils.object2):
if self.apply_nodes != nodes: if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes) missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes) excess = self.apply_nodes.difference(nodes)
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess) raise Exception(
"The nodes are inappropriately cached. missing, in excess: ",
missing, excess)
for node in nodes: for node in nodes:
if node.fgraph is not self: if node.fgraph is not self:
raise Exception("Node should belong to the FunctionGraph.", node) raise Exception("Node should belong to the FunctionGraph.",
node)
for i, variable in enumerate(node.inputs): for i, variable in enumerate(node.inputs):
if variable.fgraph is not self: if variable.fgraph is not self:
raise Exception("Input of node should belong to the FunctionGraph.", variable, (node, i)) raise Exception(
"Input of node should belong to the FunctionGraph.",
variable, (node, i))
if (node, i) not in variable.clients: if (node, i) not in variable.clients:
raise Exception("Inconsistent clients list.", (node, i), variable.clients) raise Exception("Inconsistent clients list.",
(node, i), variable.clients)
variables = set(graph.variables(self.inputs, self.outputs)) variables = set(graph.variables(self.inputs, self.outputs))
if set(self.variables) != variables: if set(self.variables) != variables:
missing = variables.difference(self.variables) missing = variables.difference(self.variables)
excess = self.variables.difference(variables) excess = self.variables.difference(variables)
raise Exception("The variables are inappropriately cached. missing, in excess: ", missing, excess) raise Exception(
"The variables are inappropriately cached. missing, in excess: ",
missing, excess)
for variable in variables: for variable in variables:
if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Constant): if (variable.owner is None and
variable not in self.inputs and
not isinstance(variable, graph.Constant)):
raise Exception("Undeclared input.", variable) raise Exception("Undeclared input.", variable)
if variable.fgraph is not self: if variable.fgraph is not self:
raise Exception("Variable should belong to the FunctionGraph.", variable) raise Exception("Variable should belong to the FunctionGraph.",
variable)
for node, i in variable.clients: for node, i in variable.clients:
if node == 'output': if node == 'output':
if self.outputs[i] is not variable: if self.outputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, self.outputs[i]) raise Exception("Inconsistent clients list.",
variable, self.outputs[i])
continue continue
if node not in nodes: if node not in nodes:
raise Exception("Client not in FunctionGraph.", variable, (node, i)) raise Exception("Client not in FunctionGraph.",
variable, (node, i))
if node.inputs[i] is not variable: if node.inputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, node.inputs[i]) raise Exception("Inconsistent clients list.",
variable, node.inputs[i])
def __str__(self): def __str__(self):
return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs)) return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs))
...@@ -659,9 +676,7 @@ class FunctionGraph(utils.object2): ...@@ -659,9 +676,7 @@ class FunctionGraph(utils.object2):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
### clone ### ### clone ###
def clone(self): def clone(self):
"""WRITEME""" """WRITEME"""
return self.clone_get_equiv()[0] return self.clone_get_equiv()[0]
......
...@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>" ...@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import logging import logging
import sys
import warnings import warnings
import theano import theano
...@@ -408,6 +409,9 @@ class PureOp(object): ...@@ -408,6 +409,9 @@ class PureOp(object):
elif config.compute_test_value == 'ignore': elif config.compute_test_value == 'ignore':
# silently skip test # silently skip test
run_perform = False run_perform = False
elif config.compute_test_value == 'pdb':
import pdb
pdb.post_mortem(sys.exc_info()[2])
else: else:
raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value) raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value)
...@@ -638,8 +642,11 @@ def get_test_value(v): ...@@ -638,8 +642,11 @@ def get_test_value(v):
For a Shared variable, it is the internal value. For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value. For another Variable, it is the content of v.tag.test_value.
""" """
v_tensor = theano.tensor.as_tensor_variable(v) if not isinstance(v, graph.Variable):
return PureOp._get_test_value(v_tensor) v_var = theano.tensor.as_tensor_variable(v)
else:
v_var = v
return PureOp._get_test_value(v_var)
def missing_test_message(msg): def missing_test_message(msg):
......
...@@ -421,7 +421,7 @@ class MergeFeature(object): ...@@ -421,7 +421,7 @@ class MergeFeature(object):
self.blacklist = [] self.blacklist = []
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct # If inputs to node change, it is not guaranteed that it is distinct
...@@ -433,14 +433,14 @@ class MergeFeature(object): ...@@ -433,14 +433,14 @@ class MergeFeature(object):
if isinstance(new_r, graph.Constant): if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r) self.process_constant(fgraph, new_r)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant): if isinstance(c, graph.Constant):
self.process_constant(fgraph, c) self.process_constant(fgraph, c)
self.process_node(fgraph, node) self.process_node(fgraph, node)
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1): if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
...@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer): ...@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer):
class Updater: class Updater:
if importer is not None: if importer is not None:
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
importer(node) importer(node)
if pruner is not None: if pruner is not None:
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
pruner(node) pruner(node)
if chin is not None: if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
...@@ -1357,7 +1357,7 @@ class ChangeTracker: ...@@ -1357,7 +1357,7 @@ class ChangeTracker:
def __init__(self): def __init__(self):
self.changed = False self.changed = False
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
self.changed = True self.changed = True
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
......
import sys import sys
import time import time
from theano import config
from theano.gof.python25 import partial from theano.gof.python25 import partial
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.gof import graph from theano.gof import graph
class AlreadyThere(Exception): class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph """Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical attempting to attach the feature already has a functionally identical
...@@ -57,7 +56,7 @@ class Feature(object): ...@@ -57,7 +56,7 @@ class Feature(object):
functionality that it installed into the function_graph. functionality that it installed into the function_graph.
""" """
def on_import(self, function_graph, node): def on_import(self, function_graph, node, reason):
""" """
Called whenever a node is imported into function_graph, which is Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph. just before the node is actually connected to the graph.
...@@ -66,7 +65,7 @@ class Feature(object): ...@@ -66,7 +65,7 @@ class Feature(object):
you should do this by implementing on_attach. you should do this by implementing on_attach.
""" """
def on_prune(self, function_graph, node): def on_prune(self, function_graph, node, reason):
""" """
Called whenever a node is pruned (removed) from the function_graph, Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph. after it is disconnected from the graph.
...@@ -98,11 +97,11 @@ class Bookkeeper(Feature): ...@@ -98,11 +97,11 @@ class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node) self.on_prune(fgraph, node, 'Bookkeeper.detach')
class History(Feature): class History(Feature):
...@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator): ...@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator):
def replace_validate(self, fgraph, r, new_r, reason=None): def replace_validate(self, fgraph, r, new_r, reason=None):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason) self.replace_all_validate(fgraph, [(r, new_r)], reason=reason)
def replace_all_validate(self, fgraph, replacements, reason=None): def replace_all_validate(self, fgraph, replacements,
reason=None, verbose=None):
chk = fgraph.checkpoint() chk = fgraph.checkpoint()
if verbose is None:
verbose = config.optimizer_verbose
for r, new_r in replacements: for r, new_r in replacements:
try: try:
fgraph.replace(r, new_r, reason=reason) fgraph.replace(r, new_r, reason=reason, verbose=False)
except Exception, e: except Exception, e:
if ('The type of the replacement must be the same' not in if ('The type of the replacement must be the same' not in
str(e) and 'does not belong to this FunctionGraph' not in str(e)): str(e) and 'does not belong to this FunctionGraph' not in str(e)):
...@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator): ...@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator):
except Exception, e: except Exception, e:
fgraph.revert(chk) fgraph.revert(chk)
raise raise
if verbose:
print reason, r, new_r
return chk return chk
def replace_all_validate_remove(self, fgraph, replacements, def replace_all_validate_remove(self, fgraph, replacements,
...@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper):
del fgraph.get_nodes del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph) Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
try: try:
self.setdefault(node.op, []).append(node) self.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
...@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper):
print >> sys.stderr, 'OFFENDING node not hashable' print >> sys.stderr, 'OFFENDING node not hashable'
raise e raise e
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
try: try:
nodes = self[node.op] nodes = self[node.op]
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
...@@ -312,13 +316,13 @@ class PrintListener(Feature): ...@@ -312,13 +316,13 @@ class PrintListener(Feature):
if self.active: if self.active:
print "-- detaching from: ", fgraph print "-- detaching from: ", fgraph
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if self.active: if self.active:
print "-- importing: %s" % node print "-- importing: %s, reason: %s" % (node, reason)
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
if self.active: if self.active:
print "-- pruning: %s" % node print "-- pruning: %s, reason: %s" % (node, reason)
def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active: if self.active:
......
...@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp):
axis = inputs[0] axis = inputs[0]
n_cndas = len(inputs[1:]) n_cndas = len(inputs[1:])
input_1 = inputs[1] input_1 = inputs[1]
axis = inputs[0]
fail = sub['fail'] fail = sub['fail']
out = out_[0] out = out_[0]
......
...@@ -137,9 +137,9 @@ class HintsFeature(object): ...@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple) # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.hints = {} self.hints = {}
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.hints: if node.outputs[0] in self.hints:
# this is a revert, not really an import # this is a revert, not really an import
for r in node.outputs + node.inputs: for r in node.outputs + node.inputs:
......
...@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph # shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately # It will call infer_shape and set_shape appropriately
dummy_fgraph = None dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner) shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = [] ret = []
for o in outs: for o in outs:
......
...@@ -183,6 +183,24 @@ class Scalar(Type): ...@@ -183,6 +183,24 @@ class Scalar(Type):
def dtype_specs(self): def dtype_specs(self):
try: try:
# To help debug dtype/typenum problem, here is code to get
# the list of numpy typenum. This list change between 32
# and 64 bit platform and probably also also between
# Windows and Linux.
# NOTE: equivalent type on a platform can have different typenum.
# This is the source of all dtype/typenum problem found up to
# now, as Theano always expect the exact typenum that
# correspond to our supported dtype.
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'float', 'double',
'int', 'uint']:
print dtype, np.zeros(1, dtype=dtype).dtype.num
"""
return { # dtype: (py_type, c_type, cls_name) return { # dtype: (py_type, c_type, cls_name)
'float32': (numpy.float32, 'npy_float32', 'Float32'), 'float32': (numpy.float32, 'npy_float32', 'Float32'),
'float64': (numpy.float64, 'npy_float64', 'Float64'), 'float64': (numpy.float64, 'npy_float64', 'Float64'),
......
...@@ -101,7 +101,7 @@ def scan(fn, ...@@ -101,7 +101,7 @@ def scan(fn,
The order of the sequences is the same as the one in the list The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the same `sequences` given to scan. The order of the outputs is the same
as the order of ``output_info``. For any sequence or output the as the order of ``outputs_info``. For any sequence or output the
order of the time slices is the same as the one in which they have order of the time slices is the same as the one in which they have
been given as taps. For example if one writes the following : been given as taps. For example if one writes the following :
...@@ -262,7 +262,7 @@ def scan(fn, ...@@ -262,7 +262,7 @@ def scan(fn,
outputs will have *0 rows*. If the value is negative, ``scan`` outputs will have *0 rows*. If the value is negative, ``scan``
will run backwards in time. If the ``go_backwards`` flag is already will run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n stpes is not provided, ``scan`` will figure in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences. out the amount of steps it should run given its input sequences.
...@@ -817,7 +817,7 @@ def scan(fn, ...@@ -817,7 +817,7 @@ def scan(fn,
if as_while: if as_while:
tmp_dummy_f_outs -= 1 tmp_dummy_f_outs -= 1
if not (tmp_dummy_f_outs == n_outs or outs_info == []): if not (tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for ' raise ValueError('Please provide None as outputs_info for '
'any output that does not feed back into ' 'any output that does not feed back into '
'scan (i.e. it behaves like a map) ') 'scan (i.e. it behaves like a map) ')
......
...@@ -1581,8 +1581,30 @@ class Scan(PureOp): ...@@ -1581,8 +1581,30 @@ class Scan(PureOp):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1]) outer_inp_seqs.append(x[::-1])
outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)] if hasattr(inputs[0].tag, 'test_value'):
outer_inp_seqs += [x[::-1] for x in self.outer_sitsot_outs(outs)] # Here we tests that the new scan input sequence all have
# the same shape[0]. This is a properties that the scan()
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs)):
mintap = numpy.min(taps)
if hasattr(x[::-1][:mintap], 'test_value'):
assert (x[::-1][:mintap].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'):
assert (x[::-1][:-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'):
assert (x[::-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
outer_inp_seqs += [x[::-1][:numpy.min(taps)]
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
inner_inp_seqs = self.inner_seqs(self_inputs) inner_inp_seqs = self.inner_seqs(self_inputs)
......
...@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_info['n_seqs'] = nw_n_seqs nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK # DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info) nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan.make_node(*nw_outer).outputs nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return nw_outs return nw_outs
else: else:
return False return False
...@@ -227,7 +227,11 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -227,7 +227,11 @@ class PushOutNonSeqScan(gof.Optimizer):
'this on theano-users list'), x) 'this on theano-users list'), x)
outside_ins = [x.type.filter_variable(y) for x, y in outside_ins = [x.type.filter_variable(y) for x, y in
zip(nd.inputs, outside_ins)] zip(nd.inputs, outside_ins)]
nw_outer_node = nd.op.make_node(*outside_ins)
# Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins,
**dict(return_list=True))[0].owner
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
...@@ -285,11 +289,15 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -285,11 +289,15 @@ class PushOutNonSeqScan(gof.Optimizer):
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs) op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node # Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info) nwScan = scan_op.Scan(op_ins, op_outs, op.info)
nw_node = nwScan.make_node(* (node.inputs + nw_outer))
# Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer),
**dict(return_list=True))[0].owner
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
zip(node.outputs, nw_node.outputs), zip(node.outputs, nw_node.outputs),
remove=[node], remove=[node],
reason='scan_push_computation_out') reason='scanOp_pushout_nonseqs_ops')
return True return True
elif to_keep == []: elif to_keep == []:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
...@@ -310,7 +318,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -310,7 +318,7 @@ class PushOutNonSeqScan(gof.Optimizer):
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
replace_with.items(), replace_with.items(),
remove=[node], remove=[node],
reason='scan_push_computation_out') reason='scanOp_pushout_nonseqs_ops')
else: else:
return False return False
...@@ -327,8 +335,8 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -327,8 +335,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph.attach_feature(gof.toolbox.ReplaceValidate()) fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, nodelist = [x for x in fgraph.toposort()
scan_op.Scan)] if isinstance(x.op, scan_op.Scan)]
for node in nodelist: for node in nodelist:
self.process_node(fgraph, node) self.process_node(fgraph, node)
...@@ -376,18 +384,21 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -376,18 +384,21 @@ class PushOutSeqScan(gof.Optimizer):
elif x in inner_seqs: elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]] outside_ins += [outer_seqs[inner_seqs.index(x)]]
elif x in to_replace: elif x in to_replace:
outside_ins += [replace_with_out[\ outside_ins += [replace_with_out[
to_replace.index(x)]] to_replace.index(x)]]
elif isinstance(x, theano.Constant): elif isinstance(x, theano.Constant):
outside_ins += [x.clone()] outside_ins += [x.clone()]
else: else:
raise Exception( raise Exception(
('Error in the `scan_pushout_non_seq_' ('Error in the `scan_pushout_seq_'
'operations`. The optimization tries ' 'operations`. The optimization tries '
'to move some computation fron scan ' 'to move some computation fron scan '
'which is not allowed to move. Report ' 'which is not allowed to move. Report '
'this on theano-users list'), x) 'this on theano-users list'), x)
nw_outer_node = nd.op.make_node(*outside_ins) # Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins,
**dict(return_list=True))[0].owner
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
...@@ -420,10 +431,15 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -420,10 +431,15 @@ class PushOutSeqScan(gof.Optimizer):
to_replace += [y] to_replace += [y]
replace_with_in += [y_place_holder] replace_with_in += [y_place_holder]
replace_with_out += [new_outer] replace_with_out += [new_outer]
if hasattr(new_outer.tag, "test_value"):
new_sh = new_outer.tag.test_value.shape
ref_sh = (outside_ins.tag.test_value.shape[0],)
ref_sh += nd.outputs[0].tag.test_value.shape
assert new_sh == ref_sh
changed = True changed = True
if counts >= max_iterations: if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_non_seq_operations`.' raise Exception('Error in the `scan_pushout_seq_operations`.'
' The optimization exhausted the maximal number ' ' The optimization exhausted the maximal number '
'of iterations allowed!') 'of iterations allowed!')
# We need to check all candidate replacements and choose those that # We need to check all candidate replacements and choose those that
...@@ -473,12 +489,14 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -473,12 +489,14 @@ class PushOutSeqScan(gof.Optimizer):
nw_info = op.info.copy() nw_info = op.info.copy()
nw_info['n_seqs'] += len(nw_inner) nw_info['n_seqs'] += len(nw_inner)
nwScan = scan_op.Scan(op_ins, op_outs, nw_info) nwScan = scan_op.Scan(op_ins, op_outs, nw_info)
nw_node = nwScan.make_node(* (node.inputs[:1] + nw_outer + # Do not call make_node for test_value
node.inputs[1:])) nw_node = nwScan(*(node.inputs[:1] + nw_outer + node.inputs[1:]),
**dict(return_list=True))[0].owner
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
zip(node.outputs, nw_node.outputs), zip(node.outputs, nw_node.outputs),
remove=[node], remove=[node],
reason='scan_push_computation_out') reason='scanOp_pushout_seqs_ops')
return True return True
elif (to_keep == [] and elif (to_keep == [] and
not op.as_while and not op.as_while and
...@@ -510,8 +528,8 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -510,8 +528,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
replace_with.items(), replace_with.items(),
remove=[node], remove=[node],
reason='scan_push_seq_computation_out') reason='scanOp_pushout_seqs_ops')
return True
else: else:
return False return False
...@@ -563,12 +581,13 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -563,12 +581,13 @@ class ScanInplaceOptimizer(Optimizer):
info, info,
typeConstructor=self.typeConstructor) typeConstructor=self.typeConstructor)
new_outs = new_op.make_node(*inputs).outputs # Do not call make_node for test_value
new_outs = new_op(*inputs, **dict(return_list=True))
try: try:
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
zip(node.outputs, new_outs), zip(node.outputs, new_outs),
remove=[node], remove=[node],
reason=self.__class__.__name__) reason='scanOp_make_inplace')
op = new_op op = new_op
node = new_outs[0].owner node = new_outs[0].owner
except InconsistencyError, e: except InconsistencyError, e:
...@@ -847,9 +866,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -847,9 +866,8 @@ class ScanSaveMem(gof.Optimizer):
nw_inputs[0] = nw_steps nw_inputs[0] = nw_steps
# 3.2 check orphane outputs to see if we can eliminate any # 3.2 check orphane outputs to see if we can eliminate any
required, not_required = \ required, not_required = scan_utils.scan_can_remove_outs(
scan_utils.scan_can_remove_outs(node.op, node.op, orphane_outs)
orphane_outs)
# 3.3. compose replace pairs for those nodes that need not # 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required # to store everything in memory ( or ar orphane and required
# by the inner function .. ) # by the inner function .. )
...@@ -947,9 +965,10 @@ class ScanSaveMem(gof.Optimizer): ...@@ -947,9 +965,10 @@ class ScanSaveMem(gof.Optimizer):
# I need to make sure I'm not reapplying the same optimization # I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that # twice since bad things usually happen if I do that
info['_scan_savemem_visited'] = True info['_scan_savemem_visited'] = True
new_outs = scan_op.Scan(inps,
outs, # Do not call make_node for test_value
info).make_node(*node_ins).outputs new_outs = scan_op.Scan(inps, outs, info)(*node_ins,
**dict(return_list=True))
old_new = [] old_new = []
# 3.7 Get replace pairs for those outputs that do not change # 3.7 Get replace pairs for those outputs that do not change
...@@ -979,8 +998,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -979,8 +998,7 @@ class ScanSaveMem(gof.Optimizer):
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos], new_o = subtens(new_outs[nw_pos], *sl_ins)
*sl_ins).outputs[0]
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]] new_o = new_o[::cnf_slice[1]]
replaced_outs.append(idx) replaced_outs.append(idx)
...@@ -1011,16 +1029,14 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1011,16 +1029,14 @@ class ScanSaveMem(gof.Optimizer):
position = (cnf_slice[0] - nw_steps - position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos]) init_l[pos] + store_steps[pos])
nw_slice = (sanitize(position),) + \ nw_slice = (sanitize(position),) + tuple(
tuple(old_slices[1:]) old_slices[1:])
subtens = tensor.Subtensor(nw_slice) subtens = tensor.Subtensor(nw_slice)
sl_ins = tensor.Subtensor.collapse( sl_ins = tensor.Subtensor.collapse(
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos], new_o = subtens(new_outs[nw_pos], *sl_ins)
*sl_ins).outputs[0]
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]] new_o = new_o[::cnf_slice[1]]
old_new += [(old, new_o)] old_new += [(old, new_o)]
...@@ -1042,7 +1058,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1042,7 +1058,7 @@ class ScanSaveMem(gof.Optimizer):
remove.append(node) remove.append(node)
fgraph.replace_all_validate_remove(old_new, fgraph.replace_all_validate_remove(old_new,
remove, remove,
reason='scan_save_mem') reason='scanOp_save_mem')
def apply(self, fgraph): def apply(self, fgraph):
...@@ -1230,7 +1246,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1230,7 +1246,7 @@ class ScanMerge(gof.Optimizer):
proposal = self.merge(subset) proposal = self.merge(subset)
fgraph.replace_all_validate_remove(proposal, fgraph.replace_all_validate_remove(proposal,
remove=subset, remove=subset,
reason='scan_merge') reason='scanOp_merge')
def has_duplicates(l): def has_duplicates(l):
...@@ -1592,10 +1608,8 @@ class PushOutDot1(gof.Optimizer): ...@@ -1592,10 +1608,8 @@ class PushOutDot1(gof.Optimizer):
old = node.outputs[pos].clients[0][0].outputs[0] old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out)) old_new.append((old, new_out))
old_new += zip(node.outputs[pos+1:], new_outs[pos:]) old_new += zip(node.outputs[pos+1:], new_outs[pos:])
fgraph.replace_all_validate_remove(old_new, fgraph.replace_all_validate_remove(
remove = [node], old_new, remove=[node], reason='scan_pushout_dot1')
reason='PushOutDot1')
# I've added an equilibrium because later scan optimization in the sequence # I've added an equilibrium because later scan optimization in the sequence
...@@ -1628,6 +1642,7 @@ scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0', ...@@ -1628,6 +1642,7 @@ scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
1, 1,
'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
...@@ -1662,10 +1677,11 @@ scan_seqopt2.register('constant_folding_for_scan2', ...@@ -1662,10 +1677,11 @@ scan_seqopt2.register('constant_folding_for_scan2',
'scan') 'scan')
scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs0', scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs1',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
2, 2,
'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
...@@ -1684,12 +1700,14 @@ scan_seqopt2.register('scanop_remove_constants_and_unused_inputs2', ...@@ -1684,12 +1700,14 @@ scan_seqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
5, 5,
'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
scan_seqopt2.register('scanOp_merge_inouts', scan_seqopt2.register('scanOp_merge_inouts',
opt.in2out(scan_merge_inouts, ignore_newtrees=True), opt.in2out(scan_merge_inouts, ignore_newtrees=True),
6, 6,
'scan_merge_inouts',
'fast_run', 'fast_run',
'scan') 'scan')
...@@ -1707,5 +1725,6 @@ scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs3', ...@@ -1707,5 +1725,6 @@ scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True), ignore_newtrees=True),
8, 8,
'remove_constants_and_unused_inputs_scan',
'fast_run', 'fast_run',
'scan') 'scan')
...@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph # shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately # It will call infer_shape and set_shape appropriately
dummy_fgraph = None dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner) shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = [] ret = []
for o in outs: for o in outs:
......
...@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase): ...@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase):
x0 = theano.tensor.vector('x0') x0 = theano.tensor.vector('x0')
y0 = theano.tensor.vector('y0') y0 = theano.tensor.vector('y0')
W_in1.tag.test_value = vW_in1
u1.tag.test_value = v_u1
u2.tag.test_value = v_u2
x0.tag.test_value = v_x0
y0.tag.test_value = v_y0
def f_rnn_cmpl(u1_t, def f_rnn_cmpl(u1_t,
u2_tm1, u2_tm1,
u2_t, u2_t,
...@@ -1553,11 +1559,21 @@ class T_Scan(unittest.TestCase): ...@@ -1553,11 +1559,21 @@ class T_Scan(unittest.TestCase):
y_tm1, y_tm1,
y_tm3, y_tm3,
W_in1): W_in1):
return [theano.dot(u1_t, W_in1) + \ return [theano.dot(u1_t, W_in1) +
(u2_t + u2_tm1 * u2_tp1) * W_in2 + \ (u2_t + u2_tm1 * u2_tp1) * W_in2 +
theano.dot(x_tm1, W), theano.dot(x_tm1, W),
(y_tm1 + y_tm3) * theano.dot(x_tm1, W_out), (y_tm1 + y_tm3) * theano.dot(x_tm1, W_out),
theano.dot(u1_t, W_in1)] theano.dot(u1_t, W_in1)]
# We change the compute_test_value[_opt] flag to run the
# assert in Scan.grad() of the new scan input sequence related
# to outer_mitsot_outs, outer_sitsot_outs and
# outer_nitsot_outs. This allow to test an old Scan bug.
old1 = theano.config.compute_test_value
old2 = theano.config.compute_test_value_opt
theano.config.compute_test_value = 'raise'
theano.config.compute_test_value_opt = 'raise'
try:
cost, updates = scan_project_sum( cost, updates = scan_project_sum(
f_rnn_cmpl, f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])], [u1, dict(input=u2, taps=[-1, 0, 1])],
...@@ -1580,6 +1596,9 @@ class T_Scan(unittest.TestCase): ...@@ -1580,6 +1596,9 @@ class T_Scan(unittest.TestCase):
updates=updates, updates=updates,
no_default_updates=True, no_default_updates=True,
allow_input_downcast=True) allow_input_downcast=True)
finally:
theano.config.compute_test_value = old1
theano.config.compute_test_value_opt = old2
num_grad = multiple_outputs_numeric_grad(cost_fn, num_grad = multiple_outputs_numeric_grad(cost_fn,
[v_u1, [v_u1,
......
...@@ -2543,7 +2543,7 @@ class Alloc(gof.Op): ...@@ -2543,7 +2543,7 @@ class Alloc(gof.Op):
#change. #change.
return [gx] + [DisconnectedType()() for i in inputs[1:]] return [gx] + [DisconnectedType()() for i in inputs[1:]]
def __call__(self, val, *shapes): def __call__(self, val, *shapes, **kwargs):
""" """
If the alloc would be useless, this function returns val. If the alloc would be useless, this function returns val.
...@@ -2554,7 +2554,7 @@ class Alloc(gof.Op): ...@@ -2554,7 +2554,7 @@ class Alloc(gof.Op):
If you always want an Alloc node, call make_node. If you always want an Alloc node, call make_node.
""" """
ret = super(Alloc, self).__call__(val, *shapes) ret = super(Alloc, self).__call__(val, *shapes, **kwargs)
try: try:
# It makes optimization difficult when useless allocs are thrown # It makes optimization difficult when useless allocs are thrown
# into the graph at every stage of optimization. This little logic # into the graph at every stage of optimization. This little logic
......
...@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error', ...@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error',
def out2in(*local_opts): def out2in(*local_opts):
"""WRITEME """ """WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='out_to_in', order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace) failure_callback=TopoOptimizer.warn_inplace)
def in2out(*local_opts, **kwargs): def in2out(*local_opts, **kwargs):
"""WRITEME """ """WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='in_to_out', order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace, failure_callback=TopoOptimizer.warn_inplace,
**kwargs) **kwargs)
...@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node): ...@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node):
input = node.inputs[0] input = node.inputs[0]
inode = input.owner inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1): if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable, # Don't use make_node to have tag.test_value set.
ret = inode.op(*[DimShuffle(input.type.broadcastable,
op.new_order, op.new_order,
op.inplace)(input) for input in op.inplace)(input) for input in
inode.inputs]).outputs inode.inputs], **dict(return_list=True))
return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order] op.new_order]
...@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node): ...@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node):
iinput.type.ndim): iinput.type.ndim):
return [iinput] return [iinput]
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, ret = DimShuffle(iinput.type.broadcastable, new_order,
inplace).make_node(iinput).outputs inplace)(iinput, **dict(return_list=True))
return ret
@register_canonicalize @register_canonicalize
...@@ -437,8 +450,10 @@ def dimshuffle_as_view(node): ...@@ -437,8 +450,10 @@ def dimshuffle_as_view(node):
# Step 60 is the inplace optimization stage: register dimshuffle_as_view
# there, tagged for fast_run, warning (not raising) on inplace failures.
_dimshuffle_as_view_opt = TopoOptimizer(
    dimshuffle_as_view,
    failure_callback=TopoOptimizer.warn_inplace)
compile.optdb.register('dimshuffle_as_view', _dimshuffle_as_view_opt,
                       60, 'fast_run', 'inplace')
register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift)
...@@ -771,7 +786,8 @@ class ShapeFeature(object): ...@@ -771,7 +786,8 @@ class ShapeFeature(object):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]: if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one return self.lscalar_one
else: else:
return Shape_i(i).make_node(r).outputs[0] # Do not call make_node for test_value
return Shape_i(i)(r)
def shape_tuple(self, r): def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r""" """Return a tuple of symbolic shape vars for tensor variable r"""
...@@ -970,9 +986,9 @@ class ShapeFeature(object): ...@@ -970,9 +986,9 @@ class ShapeFeature(object):
# shape var -> graph v # shape var -> graph v
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, reason='on_attach')
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.shape_of: if node.outputs[0] in self.shape_of:
# this is a revert, not really an import # this is a revert, not really an import
for r in node.outputs + node.inputs: for r in node.outputs + node.inputs:
...@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node): ...@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node):
sl_ins = Subtensor.collapse( sl_ins = Subtensor.collapse(
merged_slices, merged_slices,
lambda x: isinstance(x, T.Variable)) lambda x: isinstance(x, T.Variable))
out = subtens.make_node(x, *sl_ins).outputs[0] # Do not call make_node for test_value
out = subtens(x, *sl_ins)
return [out] return [out]
...@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input: elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else: else:
tmp_s_input.append(scalar.Scalar( tmp = scalar.Scalar(ii.dtype).make_variable()
ii.dtype).make_variable()) try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError:
pass
tmp_s_input.append(tmp)
tmp_input.append(ii) tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1]) tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input) s_op = i.owner.op.scalar_op(*tmp_s_input)
...@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s = s_inputs[inputs.index(i)] s = s_inputs[inputs.index(i)]
else: else:
s = scalar.Scalar(i.dtype).make_variable() s = scalar.Scalar(i.dtype).make_variable()
try:
v = gof.op.get_test_value(i)
if v.size > 0:
s.tag.test_value = gof.op.get_test_value(i).flatten()[0]
except AttributeError:
pass
inputs.append(i) inputs.append(i)
s_inputs.append(s) s_inputs.append(s)
s_g.append(s) s_g.append(s)
...@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""") ...@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""")
C = scalar.Composite(s_inputs, [s_new_out]) C = scalar.Composite(s_inputs, [s_new_out])
#create the new node. #create the new node.
n = OP(C).make_node(*inputs) #Do not call make_node to have test_value
n = OP(C)(*inputs).owner
assert len(n.outputs) == 1 assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype assert node.outputs[0].dtype == n.outputs[0].dtype
if config.tensor.local_elemwise_fusion:
    # Fusion enabled by config: also tag it for the fast_run mode.
    _logger.debug("enabling optimization fusion elemwise in fast_run")
    _fusion_tags = ('fast_run', 'fusion', 'local_elemwise_fusion',
                    'FusionOptimizer')
else:
    # Fusion disabled by config: register it, but without the fast_run tag.
    _logger.debug("not enabling optimization fusion elemwise in fast_run")
    _fusion_tags = ('fusion', 'local_elemwise_fusion', 'FusionOptimizer')
compile.optdb.register('elemwise_fusion',
                       FusionOptimizer(local_elemwise_fusion), 71.00,
                       *_fusion_tags)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论