提交 a351e3db authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1482 from nouiz/rnade

Scan crash fix
......@@ -482,6 +482,14 @@ import theano and print the config variable, as in:
This flag's value cannot be modified during the program execution.
.. attribute:: optimizer_verbose
Bool value: either True or False
Default: False
When True, we print on the stdout the optimization applied.
.. attribute:: nocleanup
Bool value: either True or False
......@@ -630,6 +638,12 @@ import theano and print the config variable, as in:
this Op
- ``'raise'`` will raise an Exception
.. attribute:: config.compute_test_value_opt
As ``compute_test_value``, but it is the value used during Theano
optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization.
.. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``.
......
......@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object):
self.reasons = {}
self.replaced_by = {}
self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
assert fgraph is self.fgraph
self.fgraph = None
def on_prune(self, fgraph, node):
self.event_list.append(_FunctionGraphEvent('prune', node))
def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason))
#print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes
assert node not in self.inactive_nodes
self.active_nodes.remove(node)
self.inactive_nodes.add(node)
def on_import(self, fgraph, node):
self.event_list.append(_FunctionGraphEvent('import', node))
def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason))
#print 'NEW NODE', node, id(node)
assert node not in self.active_nodes
......@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value
try:
theano.config.compute_test_value = "off"
theano.config.compute_test_value = theano.config.compute_test_value_opt
optimizer(fgraph)
theano.compile.function_module.insert_deepcopy(fgraph, inputs,
......
......@@ -1018,7 +1018,7 @@ class FunctionMaker(object):
compute_test_value_orig = theano.config.compute_test_value
add_stack_trace_on_call = gof.Op.add_stack_trace_on_call
try:
theano.config.compute_test_value = "off"
theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
......
......@@ -157,6 +157,11 @@ AddConfigVar('optimizer',
EnumStr('fast_run', 'merge', 'fast_compile', 'None'),
in_c_key=False)
AddConfigVar('optimizer_verbose',
"If True, we print all optimization being applied",
BoolParam(False),
in_c_key=False)
AddConfigVar('on_opt_error',
("What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."),
......@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value',
"Constants, SharedVariables and the tag 'test_value' as inputs "
"to the function. This helps the user track down problems in the "
"graph before it gets optimized."),
EnumStr('off', 'ignore', 'warn', 'raise'),
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False)
AddConfigVar('compute_test_value_opt',
("For debugging Theano optimization only."
" Same as compute_test_value, but is used"
" during Theano optimization"),
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False)
"""Note to developers:
Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity
......
......@@ -380,7 +380,7 @@ if 0:
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_import(self, fgraph, app):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
......@@ -410,7 +410,7 @@ if 0:
self.stale_droot = True
def on_prune(self, fgraph, app):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app)
......@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_import(self, fgraph, app):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import")
......@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self.stale_droot = True
def on_prune(self, fgraph, app):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import")
self.debug_all_apps.remove(app)
......
......@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise
"""
import sys
from theano.gof import graph
from theano.gof import utils
from theano.gof import toolbox
......@@ -16,6 +17,7 @@ NullType = None
from theano.gof.python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet
class InconsistencyError(Exception):
"""
This exception should be thrown by listeners to FunctionGraph when the
......@@ -82,7 +84,8 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set.
self._features = []
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self.apply_nodes = set()
# Ditto for variable nodes
......@@ -104,7 +107,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input)
self.variables.add(input)
self.__import_r__(outputs)
self.__import_r__(outputs, reason="init")
for i, output in enumerate(outputs):
output.clients.append(('output', i))
......@@ -112,12 +115,12 @@ class FunctionGraph(utils.object2):
self.variable_locks = {}
self.profile = None
### Setup a Variable ###
def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
if hasattr(r, 'fgraph') and r.fgraph is not None and r.fgraph is not self:
if (hasattr(r, 'fgraph') and
r.fgraph is not None and
r.fgraph is not self):
raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self
r.clients = []
......@@ -165,13 +168,13 @@ class FunctionGraph(utils.object2):
self.inputs = None
self.outputs = None
### clients ###
def clients(self, r):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have r as input at index i.
Tell differently, a list of (node,i) such that each node have
r as input at index i.
"""
return r.clients
......@@ -184,12 +187,15 @@ class FunctionGraph(utils.object2):
"""
if set(r.clients).intersection(set(new_clients)):
print >> sys.stderr, 'ERROR: clients intersect!'
print >> sys.stderr, ' RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients]
print >> sys.stderr, ' NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients]
print >> sys.stderr, ' RCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in r.clients]
print >> sys.stderr, ' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients]
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune = True):
def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None):
""" WRITEME
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
......@@ -202,18 +208,16 @@ class FunctionGraph(utils.object2):
print >> sys.stderr, 'ERROR: DUPLICATE CLIENT ENTRY...'
print >> sys.stderr, ' ENTRY', repr(entry), type(entry[0])
print >> sys.stderr, ' CLIENTS', repr(r.clients)
assert entry not in r.clients # an op,i pair should be unique
assert entry not in r.clients # an op,i pair should be unique
if not r.clients:
if prune:
self.__prune_r__([r])
self.__prune_r__([r], reason)
return False
return True
return False
### import ###
def __import_r__(self, variables):
def __import_r__(self, variables, reason):
global NullType
if NullType is None:
from null_type import NullType
......@@ -222,17 +226,18 @@ class FunctionGraph(utils.object2):
for apply_node in [r.owner for r in variables if r.owner is not None]:
if apply_node not in r_owner_done:
r_owner_done.add(apply_node)
self.__import__(apply_node)
self.__import__(apply_node, reason=reason)
for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NullType):
raise TypeError("Computation graph contains a NaN. "+r.type.why_null)
if isinstance(r.type, NullType):
raise TypeError("Computation graph contains a NaN. " +
r.type.why_null)
raise MissingInputError("Undeclared input", r)
if not getattr(r, 'fgraph', None) is self:
self.__setup_r__(r)
self.variables.add(r)
def __import__(self, apply_node, check = True):
def __import__(self, apply_node, check=True, reason=None):
node = apply_node
# We import the nodes in topological order. We only are interested
......@@ -248,7 +253,9 @@ class FunctionGraph(utils.object2):
for r in node.inputs:
if hasattr(r, 'fgraph') and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r)
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if (r.owner is None and
not isinstance(r, graph.Constant) and
r not in self.inputs):
#Verbose error message
#Show a complete chain of variables from the missing input to an output
......@@ -328,20 +335,18 @@ class FunctionGraph(utils.object2):
self.variables.add(input)
self.__add_clients__(input, [(node, i)])
assert node.fgraph is self
self.execute_callbacks('on_import', node)
self.execute_callbacks('on_import', node, reason)
### prune ###
def __prune_r__(self, variables):
def __prune_r__(self, variables, reason=None):
# Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None):
self.__prune__(node)
self.__prune__(node, reason)
for r in variables:
if not r.clients and r in self.variables:
self.variables.remove(r)
def __prune__(self, apply_node):
def __prune__(self, apply_node, reason=None):
node = apply_node
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
......@@ -356,16 +361,13 @@ class FunctionGraph(utils.object2):
return
self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node)
self.execute_callbacks('on_prune', node, reason)
for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)])
self.__remove_clients__(input, [(node, i)], reason=reason)
#self.__prune_r__(node.inputs)
### change input ###
def change_input(self, node, i, new_r, reason=None):
"""WRITEME
Changes node.inputs[i] to new_r.
......@@ -381,42 +383,45 @@ class FunctionGraph(utils.object2):
r = self.outputs[i]
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.",
r, new_r)
" same as the type of the original Variable.",
r, new_r)
self.outputs[i] = new_r
else:
if node.fgraph is not self:
raise Exception("Cannot operate on %s because it does not"
" belong to this FunctionGraph" % node)
" belong to this FunctionGraph" % node)
r = node.inputs[i]
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.",
r, new_r)
" same as the type of the original Variable.",
r, new_r)
node.inputs[i] = new_r
if r is new_r:
return
self.__import_r__([new_r])
self.__import_r__([new_r], reason=reason)
self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
self.execute_callbacks('on_change_input', node, i, r, new_r, reason=reason)
self.execute_callbacks('on_change_input', node, i,
r, new_r, reason=reason)
if prune:
self.__prune_r__([r])
self.__prune_r__([r], reason=reason)
### replace ###
def replace(self, r, new_r, reason=None):
def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
"""
if verbose is None:
verbose = config.optimizer_verbose
if verbose:
print reason, r, new_r
if r.fgraph is not self:
raise Exception("Cannot replace %s because it does not belong to this FunctionGraph" % r, str(reason))
if not r.type == new_r.type:
......@@ -426,7 +431,7 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
for node, i in list(r.clients): # copy the client list for iteration
for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason)
......@@ -440,11 +445,9 @@ class FunctionGraph(utils.object2):
for r, new_r in pairs:
self.replace(r, new_r, reason=reason)
def extend(self, feature):
warnings.warn("FunctionGraph.extend is deprecatd. It has been "
"renamed to FunctionGraph.attach_feature")
"renamed to FunctionGraph.attach_feature")
return self.attach_feature(feature)
def attach_feature(self, feature):
......@@ -455,7 +458,7 @@ class FunctionGraph(utils.object2):
# Filter out literally identical features
if feature in self._features:
return # the feature is already present
return # the feature is already present
# Filter out functionally identical features.
# Features may use their on_attach method to raise
......@@ -481,7 +484,9 @@ class FunctionGraph(utils.object2):
"""WRITEME
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method is defined.
Calls feature.on_detach(function_graph) if an on_detach method
is defined.
"""
try:
self._features.remove(feature)
......@@ -491,9 +496,7 @@ class FunctionGraph(utils.object2):
if detach is not None:
detach(self)
### callback utils ###
def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME
Calls
......@@ -518,7 +521,6 @@ class FunctionGraph(utils.object2):
else:
raise
def collect_callbacks(self, name, *args):
"""WRITEME
Returns a dictionary d such that:
......@@ -534,9 +536,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args)
return d
### misc ###
def toposort(self):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
......@@ -552,8 +552,8 @@ class FunctionGraph(utils.object2):
if len(self.apply_nodes) < 2:
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker produces
# 1-element graphs.
# This special case happens a lot because the OpWiseCLinker
# produces 1-element graphs.
return list(self.apply_nodes)
fg = self
......@@ -568,30 +568,33 @@ class FunctionGraph(utils.object2):
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all
clients of any destroyed inputs have already computed their outputs.
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their
outputs.
:note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
"""
ords = OrderedDict()
ords = OrderedDict()
assert isinstance(self._features, list)
for feature in self._features:
if hasattr(feature, 'orderings'):
orderings = feature.orderings(self)
if not isinstance(orderings, OrderedDict):
raise TypeError("Non-deterministic return value from " \
+str(feature.orderings) \
+". Nondeterministic object is "+str(orderings))
raise TypeError("Non-deterministic return value from " +
str(feature.orderings) +
". Nondeterministic object is " +
str(orderings))
for node, prereqs in orderings.items():
if not isinstance(prereqs, (list, OrderedSet)):
raise TypeError("prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic.")
raise TypeError(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic.")
ords.setdefault(node, []).extend(prereqs)
# eliminate duplicate prereqs
for (node,prereqs) in ords.items():
for (node, prereqs) in ords.items():
ords[node] = list(OrderedSet(prereqs))
return ords
......@@ -624,34 +627,48 @@ class FunctionGraph(utils.object2):
if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes)
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess)
raise Exception(
"The nodes are inappropriately cached. missing, in excess: ",
missing, excess)
for node in nodes:
if node.fgraph is not self:
raise Exception("Node should belong to the FunctionGraph.", node)
raise Exception("Node should belong to the FunctionGraph.",
node)
for i, variable in enumerate(node.inputs):
if variable.fgraph is not self:
raise Exception("Input of node should belong to the FunctionGraph.", variable, (node, i))
raise Exception(
"Input of node should belong to the FunctionGraph.",
variable, (node, i))
if (node, i) not in variable.clients:
raise Exception("Inconsistent clients list.", (node, i), variable.clients)
raise Exception("Inconsistent clients list.",
(node, i), variable.clients)
variables = set(graph.variables(self.inputs, self.outputs))
if set(self.variables) != variables:
missing = variables.difference(self.variables)
excess = self.variables.difference(variables)
raise Exception("The variables are inappropriately cached. missing, in excess: ", missing, excess)
raise Exception(
"The variables are inappropriately cached. missing, in excess: ",
missing, excess)
for variable in variables:
if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Constant):
if (variable.owner is None and
variable not in self.inputs and
not isinstance(variable, graph.Constant)):
raise Exception("Undeclared input.", variable)
if variable.fgraph is not self:
raise Exception("Variable should belong to the FunctionGraph.", variable)
raise Exception("Variable should belong to the FunctionGraph.",
variable)
for node, i in variable.clients:
if node == 'output':
if self.outputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, self.outputs[i])
raise Exception("Inconsistent clients list.",
variable, self.outputs[i])
continue
if node not in nodes:
raise Exception("Client not in FunctionGraph.", variable, (node, i))
raise Exception("Client not in FunctionGraph.",
variable, (node, i))
if node.inputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, node.inputs[i])
raise Exception("Inconsistent clients list.",
variable, node.inputs[i])
def __str__(self):
return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs))
......@@ -659,9 +676,7 @@ class FunctionGraph(utils.object2):
def __repr__(self):
return self.__str__()
### clone ###
def clone(self):
"""WRITEME"""
return self.clone_get_equiv()[0]
......@@ -671,7 +686,7 @@ class FunctionGraph(utils.object2):
equiv = graph.clone_get_equiv(self.inputs, self.outputs)
self.check_integrity()
e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs])
[equiv[o] for o in self.outputs])
e.check_integrity()
for feature in self._features:
e.attach_feature(feature)
......
......@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
import logging
import sys
import warnings
import theano
......@@ -408,6 +409,9 @@ class PureOp(object):
elif config.compute_test_value == 'ignore':
# silently skip test
run_perform = False
elif config.compute_test_value == 'pdb':
import pdb
pdb.post_mortem(sys.exc_info()[2])
else:
raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value)
......@@ -638,8 +642,11 @@ def get_test_value(v):
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
v_tensor = theano.tensor.as_tensor_variable(v)
return PureOp._get_test_value(v_tensor)
if not isinstance(v, graph.Variable):
v_var = theano.tensor.as_tensor_variable(v)
else:
v_var = v
return PureOp._get_test_value(v_var)
def missing_test_message(msg):
......
......@@ -421,7 +421,7 @@ class MergeFeature(object):
self.blacklist = []
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct
......@@ -433,14 +433,14 @@ class MergeFeature(object):
if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r)
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
for c in node.inputs:
if isinstance(c, graph.Constant):
self.process_constant(fgraph, c)
self.process_node(fgraph, node)
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node)
for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
......@@ -548,7 +548,7 @@ class MergeOptimizer(Optimizer):
except InconsistencyError:
success = False
fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
(pairs[0][0].owner, pairs[0][1].owner))
if success:
break
......@@ -1027,7 +1027,7 @@ class PatternSub(LocalOptimizer):
else:
return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True,
self.pdb)
self.pdb)
if u:
p = self.out_pattern
new = build(p, u)
......@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer):
class Updater:
if importer is not None:
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
importer(node)
if pruner is not None:
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
pruner(node)
if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r):
......@@ -1357,7 +1357,7 @@ class ChangeTracker:
def __init__(self):
self.changed = False
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
self.changed = True
def on_change_input(self, fgraph, node, i, r, new_r):
......
import sys
import time
from theano import config
from theano.gof.python25 import partial
from theano.gof.python25 import OrderedDict
from theano.gof import graph
class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
......@@ -57,7 +56,7 @@ class Feature(object):
functionality that it installed into the function_graph.
"""
def on_import(self, function_graph, node):
def on_import(self, function_graph, node, reason):
"""
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
......@@ -66,7 +65,7 @@ class Feature(object):
you should do this by implementing on_attach.
"""
def on_prune(self, function_graph, node):
def on_prune(self, function_graph, node, reason):
"""
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
......@@ -98,11 +97,11 @@ class Bookkeeper(Feature):
def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node)
self.on_prune(fgraph, node, 'Bookkeeper.detach')
class History(Feature):
......@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator):
def replace_validate(self, fgraph, r, new_r, reason=None):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason)
def replace_all_validate(self, fgraph, replacements, reason=None):
def replace_all_validate(self, fgraph, replacements,
reason=None, verbose=None):
chk = fgraph.checkpoint()
if verbose is None:
verbose = config.optimizer_verbose
for r, new_r in replacements:
try:
fgraph.replace(r, new_r, reason=reason)
fgraph.replace(r, new_r, reason=reason, verbose=False)
except Exception, e:
if ('The type of the replacement must be the same' not in
str(e) and 'does not belong to this FunctionGraph' not in str(e)):
......@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator):
except Exception, e:
fgraph.revert(chk)
raise
if verbose:
print reason, r, new_r
return chk
def replace_all_validate_remove(self, fgraph, replacements,
......@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper):
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
try:
self.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
......@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper):
print >> sys.stderr, 'OFFENDING node not hashable'
raise e
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
try:
nodes = self[node.op]
except TypeError: # node.op is unhashable
......@@ -312,13 +316,13 @@ class PrintListener(Feature):
if self.active:
print "-- detaching from: ", fgraph
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if self.active:
print "-- importing: %s" % node
print "-- importing: %s, reason: %s" % (node, reason)
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
if self.active:
print "-- pruning: %s" % node
print "-- pruning: %s, reason: %s" % (node, reason)
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
......
......@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp):
axis = inputs[0]
n_cndas = len(inputs[1:])
input_1 = inputs[1]
axis = inputs[0]
fail = sub['fail']
out = out_[0]
......
......@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.hints = {}
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.hints:
# this is a revert, not really an import
for r in node.outputs + node.inputs:
......
......@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner)
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = []
for o in outs:
......
......@@ -183,6 +183,24 @@ class Scalar(Type):
def dtype_specs(self):
try:
# To help debug dtype/typenum problem, here is code to get
# the list of numpy typenum. This list change between 32
# and 64 bit platform and probably also also between
# Windows and Linux.
# NOTE: equivalent type on a platform can have different typenum.
# This is the source of all dtype/typenum problem found up to
# now, as Theano always expect the exact typenum that
# correspond to our supported dtype.
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'float', 'double',
'int', 'uint']:
print dtype, np.zeros(1, dtype=dtype).dtype.num
"""
return { # dtype: (py_type, c_type, cls_name)
'float32': (numpy.float32, 'npy_float32', 'Float32'),
'float64': (numpy.float64, 'npy_float64', 'Float64'),
......
......@@ -101,7 +101,7 @@ def scan(fn,
The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the same
as the order of ``output_info``. For any sequence or output the
as the order of ``outputs_info``. For any sequence or output the
order of the time slices is the same as the one in which they have
been given as taps. For example if one writes the following :
......@@ -262,7 +262,7 @@ def scan(fn,
outputs will have *0 rows*. If the value is negative, ``scan``
will run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n stpes is not provided, ``scan`` will figure
in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
......@@ -817,7 +817,7 @@ def scan(fn,
if as_while:
tmp_dummy_f_outs -= 1
if not (tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for '
raise ValueError('Please provide None as outputs_info for '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) ')
......
......@@ -1581,8 +1581,30 @@ class Scan(PureOp):
if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1])
outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_sitsot_outs(outs)]
if hasattr(inputs[0].tag, 'test_value'):
# Here we tests that the new scan input sequence all have
# the same shape[0]. This is a properties that the scan()
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs)):
mintap = numpy.min(taps)
if hasattr(x[::-1][:mintap], 'test_value'):
assert (x[::-1][:mintap].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'):
assert (x[::-1][:-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'):
assert (x[::-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
outer_inp_seqs += [x[::-1][:numpy.min(taps)]
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
inner_inp_seqs = self.inner_seqs(self_inputs)
......
......@@ -66,7 +66,7 @@ def remove_constants_and_unused_inputs_scan(node):
# We only need to take care of sequences and other arguments
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
st += op.n_sit_sot
st += op.n_shared_outs
op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)
......@@ -105,8 +105,8 @@ def remove_constants_and_unused_inputs_scan(node):
elif op_ins[idx] in all_ins:
# Check for identical other sequence
identical_seqs = [x for x in nw_outer
if scan_utils.equal_computations(
[x], [node.inputs[idx + 1]])]
if scan_utils.equal_computations(
[x], [node.inputs[idx + 1]])]
if identical_seqs:
index = node.inputs.index(identical_seqs[0]) - 1
givens[op_ins[idx]] = op_ins[index]
......@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan.make_node(*nw_outer).outputs
nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return nw_outs
else:
return False
......@@ -162,7 +162,7 @@ class PushOutNonSeqScan(gof.Optimizer):
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
scan_op.Scan)]
scan_op.Scan)]
for node in nodelist:
self.process_node(fgraph, node)
......@@ -170,7 +170,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# this flag tells if there was any change during the last iterations
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs)
max_iterations = 2 * len(local_fgraph.toposort()) + 3
......@@ -196,7 +196,7 @@ class PushOutNonSeqScan(gof.Optimizer):
if (numpy.all([(x in inner_non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
for x in nd.inputs]) and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
......@@ -227,7 +227,11 @@ class PushOutNonSeqScan(gof.Optimizer):
'this on theano-users list'), x)
outside_ins = [x.type.filter_variable(y) for x, y in
zip(nd.inputs, outside_ins)]
nw_outer_node = nd.op.make_node(*outside_ins)
# Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins,
**dict(return_list=True))[0].owner
# Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs):
......@@ -250,7 +254,7 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort()
if nd not in to_remove]
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
......@@ -270,8 +274,8 @@ class PushOutNonSeqScan(gof.Optimizer):
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
repl_in = repl_out.clone()
else:
......@@ -285,11 +289,15 @@ class PushOutNonSeqScan(gof.Optimizer):
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info)
nw_node = nwScan.make_node(* (node.inputs + nw_outer))
# Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer),
**dict(return_list=True))[0].owner
fgraph.replace_all_validate_remove(
zip(node.outputs, nw_node.outputs),
remove=[node],
reason='scan_push_computation_out')
reason='scanOp_pushout_nonseqs_ops')
return True
elif to_keep == []:
# Nothing in the inner graph should be kept
......@@ -310,7 +318,7 @@ class PushOutNonSeqScan(gof.Optimizer):
fgraph.replace_all_validate_remove(
replace_with.items(),
remove=[node],
reason='scan_push_computation_out')
reason='scanOp_pushout_nonseqs_ops')
else:
return False
......@@ -327,8 +335,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
scan_op.Scan)]
nodelist = [x for x in fgraph.toposort()
if isinstance(x.op, scan_op.Scan)]
for node in nodelist:
self.process_node(fgraph, node)
......@@ -336,7 +344,7 @@ class PushOutSeqScan(gof.Optimizer):
# this flag tells if there was any change during the last iterations
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs)
max_iterations = 2 * len(local_fgraph.toposort()) + 3
......@@ -361,12 +369,12 @@ class PushOutSeqScan(gof.Optimizer):
for nd in local_fgraph.toposort():
if (isinstance(nd.op, theano.tensor.Elemwise) and
numpy.all([(x in inner_non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant) or
(x in inner_seqs)
for x in nd.inputs]) and
not nd in to_remove):
numpy.all([(x in inner_non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant) or
(x in inner_seqs)
for x in nd.inputs]) and
not nd in to_remove):
to_remove.append(nd)
outside_ins = []
for x in nd.inputs:
......@@ -376,18 +384,21 @@ class PushOutSeqScan(gof.Optimizer):
elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]]
elif x in to_replace:
outside_ins += [replace_with_out[\
to_replace.index(x)]]
outside_ins += [replace_with_out[
to_replace.index(x)]]
elif isinstance(x, theano.Constant):
outside_ins += [x.clone()]
else:
raise Exception(
('Error in the `scan_pushout_non_seq_'
('Error in the `scan_pushout_seq_'
'operations`. The optimization tries '
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'), x)
nw_outer_node = nd.op.make_node(*outside_ins)
# Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins,
**dict(return_list=True))[0].owner
# Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs):
......@@ -420,10 +431,15 @@ class PushOutSeqScan(gof.Optimizer):
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [new_outer]
if hasattr(new_outer.tag, "test_value"):
new_sh = new_outer.tag.test_value.shape
ref_sh = (outside_ins.tag.test_value.shape[0],)
ref_sh += nd.outputs[0].tag.test_value.shape
assert new_sh == ref_sh
changed = True
if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_non_seq_operations`.'
raise Exception('Error in the `scan_pushout_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!')
# We need to check all candidate replacements and choose those that
......@@ -436,7 +452,7 @@ class PushOutSeqScan(gof.Optimizer):
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort()
if nd not in to_remove]
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
......@@ -456,8 +472,8 @@ class PushOutSeqScan(gof.Optimizer):
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
repl_in = repl_out.clone()
else:
......@@ -473,12 +489,14 @@ class PushOutSeqScan(gof.Optimizer):
nw_info = op.info.copy()
nw_info['n_seqs'] += len(nw_inner)
nwScan = scan_op.Scan(op_ins, op_outs, nw_info)
nw_node = nwScan.make_node(* (node.inputs[:1] + nw_outer +
node.inputs[1:]))
# Do not call make_node for test_value
nw_node = nwScan(*(node.inputs[:1] + nw_outer + node.inputs[1:]),
**dict(return_list=True))[0].owner
fgraph.replace_all_validate_remove(
zip(node.outputs, nw_node.outputs),
remove=[node],
reason='scan_push_computation_out')
reason='scanOp_pushout_seqs_ops')
return True
elif (to_keep == [] and
not op.as_while and
......@@ -510,8 +528,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph.replace_all_validate_remove(
replace_with.items(),
remove=[node],
reason='scan_push_seq_computation_out')
reason='scanOp_pushout_seqs_ops')
return True
else:
return False
......@@ -532,7 +550,7 @@ class ScanInplaceOptimizer(Optimizer):
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and
x.op.info['gpu'] == self.gpu_flag)]
x.op.info['gpu'] == self.gpu_flag)]
for scan_idx in xrange(len(scan_nodes)):
node = scan_nodes[scan_idx]
op = node.op
......@@ -563,12 +581,13 @@ class ScanInplaceOptimizer(Optimizer):
info,
typeConstructor=self.typeConstructor)
new_outs = new_op.make_node(*inputs).outputs
# Do not call make_node for test_value
new_outs = new_op(*inputs, **dict(return_list=True))
try:
fgraph.replace_all_validate_remove(
zip(node.outputs, new_outs),
remove=[node],
reason=self.__class__.__name__)
reason='scanOp_make_inplace')
op = new_op
node = new_outs[0].owner
except InconsistencyError, e:
......@@ -720,7 +739,7 @@ class ScanSaveMem(gof.Optimizer):
except KeyError:
length = out.shape[0]
cf_slice = tensor.get_canonical_form_slice(
this_slice[0], length)
this_slice[0], length)
slices[i] += [(cf_slice, this_slice)]
if (isinstance(this_slice[0], slice) and
......@@ -847,9 +866,8 @@ class ScanSaveMem(gof.Optimizer):
nw_inputs[0] = nw_steps
# 3.2 check orphane outputs to see if we can eliminate any
required, not_required = \
scan_utils.scan_can_remove_outs(node.op,
orphane_outs)
required, not_required = scan_utils.scan_can_remove_outs(
node.op, orphane_outs)
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required
# by the inner function .. )
......@@ -947,9 +965,10 @@ class ScanSaveMem(gof.Optimizer):
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
info['_scan_savemem_visited'] = True
new_outs = scan_op.Scan(inps,
outs,
info).make_node(*node_ins).outputs
# Do not call make_node for test_value
new_outs = scan_op.Scan(inps, outs, info)(*node_ins,
**dict(return_list=True))
old_new = []
# 3.7 Get replace pairs for those outputs that do not change
......@@ -978,9 +997,8 @@ class ScanSaveMem(gof.Optimizer):
sl_ins = tensor.Subtensor.collapse(
nw_slice,
lambda entry: isinstance(entry,
tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0]
tensor.Variable))
new_o = subtens(new_outs[nw_pos], *sl_ins)
if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]]
replaced_outs.append(idx)
......@@ -1009,18 +1027,16 @@ class ScanSaveMem(gof.Optimizer):
else:
position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos])
nw_slice = (sanitize(position),) + \
tuple(old_slices[1:])
init_l[pos] + store_steps[pos])
nw_slice = (sanitize(position),) + tuple(
old_slices[1:])
subtens = tensor.Subtensor(nw_slice)
sl_ins = tensor.Subtensor.collapse(
nw_slice,
lambda entry: isinstance(entry,
tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0]
new_o = subtens(new_outs[nw_pos], *sl_ins)
if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]]
old_new += [(old, new_o)]
......@@ -1042,12 +1058,12 @@ class ScanSaveMem(gof.Optimizer):
remove.append(node)
fgraph.replace_all_validate_remove(old_new,
remove,
reason='scan_save_mem')
reason='scanOp_save_mem')
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
scan_op.Scan)]
scan_op.Scan)]
for node in nodelist:
if not hasattr(node.op, '_scan_savemem_visited'):
self.process_node(fgraph, node)
......@@ -1230,7 +1246,7 @@ class ScanMerge(gof.Optimizer):
proposal = self.merge(subset)
fgraph.replace_all_validate_remove(proposal,
remove=subset,
reason='scan_merge')
reason='scanOp_merge')
def has_duplicates(l):
......@@ -1389,13 +1405,13 @@ def scan_merge_inouts(node):
# items scan is supposed to store for this nit_sot sequence
shapes.append(x)
tmp = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
na.outer_out_nit_sot = [map_nitsot_out(i, o, sh, seen)
for i, o, sh in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
na.outer_out_nit_sot,
shapes)]
seen = []
na.outer_out_sit_sot = [map_out(i, o, seen)
......@@ -1592,10 +1608,8 @@ class PushOutDot1(gof.Optimizer):
old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out))
old_new += zip(node.outputs[pos+1:], new_outs[pos:])
fgraph.replace_all_validate_remove(old_new,
remove = [node],
reason='PushOutDot1')
fgraph.replace_all_validate_remove(
old_new, remove=[node], reason='scan_pushout_dot1')
# I've added an equilibrium because later scan optimization in the sequence
......@@ -1612,7 +1626,7 @@ optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeConstructor=None,
gpu_flag=False),
gpu_flag=False),
75,
'fast_run',
'inplace',
......@@ -1628,6 +1642,7 @@ scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
1,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
......@@ -1662,10 +1677,11 @@ scan_seqopt2.register('constant_folding_for_scan2',
'scan')
scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs0',
scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs1',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
2,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
......@@ -1684,12 +1700,14 @@ scan_seqopt2.register('scanop_remove_constants_and_unused_inputs2',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
5,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
scan_seqopt2.register('scanOp_merge_inouts',
opt.in2out(scan_merge_inouts, ignore_newtrees=True),
6,
'scan_merge_inouts',
'fast_run',
'scan')
......@@ -1707,5 +1725,6 @@ scan_seqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees=True),
8,
'remove_constants_and_unused_inputs_scan',
'fast_run',
'scan')
......@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner)
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = []
for o in outs:
......
......@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase):
go_backwards=False)
gX, gY = tensor.grad(values[1].sum(), [x, y])
f = theano.function([c, x, y], [gX, gY],
allow_input_downcast=True)
allow_input_downcast=True)
# Check for runtime errors
f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))
......@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase):
x0 = theano.tensor.vector('x0')
y0 = theano.tensor.vector('y0')
W_in1.tag.test_value = vW_in1
u1.tag.test_value = v_u1
u2.tag.test_value = v_u2
x0.tag.test_value = v_x0
y0.tag.test_value = v_y0
def f_rnn_cmpl(u1_t,
u2_tm1,
u2_t,
......@@ -1553,33 +1559,46 @@ class T_Scan(unittest.TestCase):
y_tm1,
y_tm3,
W_in1):
return [theano.dot(u1_t, W_in1) + \
(u2_t + u2_tm1 * u2_tp1) * W_in2 + \
theano.dot(x_tm1, W),
return [theano.dot(u1_t, W_in1) +
(u2_t + u2_tm1 * u2_tp1) * W_in2 +
theano.dot(x_tm1, W),
(y_tm1 + y_tm3) * theano.dot(x_tm1, W_out),
theano.dot(u1_t, W_in1)]
cost, updates = scan_project_sum(
f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1]
params = [u1, u2, x0, y0, W_in1]
gparams = theano.tensor.grad(cost, params)
grad_fn = theano.function([u1, u2, x0, y0, W_in1],
gparams,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
cost_fn = theano.function([u1, u2, x0, y0, W_in1],
cost,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
# We change the compute_test_value[_opt] flag to run the
# assert in Scan.grad() of the new scan input sequence related
# to outer_mitsot_outs, outer_sitsot_outs and
# outer_nitsot_outs. This allow to test an old Scan bug.
old1 = theano.config.compute_test_value
old2 = theano.config.compute_test_value_opt
theano.config.compute_test_value = 'raise'
theano.config.compute_test_value_opt = 'raise'
try:
cost, updates = scan_project_sum(
f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1]
params = [u1, u2, x0, y0, W_in1]
gparams = theano.tensor.grad(cost, params)
grad_fn = theano.function([u1, u2, x0, y0, W_in1],
gparams,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
cost_fn = theano.function([u1, u2, x0, y0, W_in1],
cost,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
finally:
theano.config.compute_test_value = old1
theano.config.compute_test_value_opt = old2
num_grad = multiple_outputs_numeric_grad(cost_fn,
[v_u1,
......
......@@ -2543,7 +2543,7 @@ class Alloc(gof.Op):
#change.
return [gx] + [DisconnectedType()() for i in inputs[1:]]
def __call__(self, val, *shapes):
def __call__(self, val, *shapes, **kwargs):
"""
If the alloc would be useless, this function returns val.
......@@ -2554,7 +2554,7 @@ class Alloc(gof.Op):
If you always want an Alloc node, call make_node.
"""
ret = super(Alloc, self).__call__(val, *shapes)
ret = super(Alloc, self).__call__(val, *shapes, **kwargs)
try:
# It makes optimization difficult when useless allocs are thrown
# into the graph at every stage of optimization. This little logic
......
......@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error',
def out2in(*local_opts):
"""WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace)
def in2out(*local_opts, **kwargs):
"""WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
......@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node):
input = node.inputs[0]
inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in
inode.inputs]).outputs
# Don't use make_node to have tag.test_value set.
ret = inode.op(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in
inode.inputs], **dict(return_list=True))
return ret
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order]
......@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node):
iinput.type.ndim):
return [iinput]
else:
return DimShuffle(iinput.type.broadcastable, new_order,
inplace).make_node(iinput).outputs
ret = DimShuffle(iinput.type.broadcastable, new_order,
inplace)(iinput, **dict(return_list=True))
return ret
@register_canonicalize
......@@ -437,8 +450,10 @@ def dimshuffle_as_view(node):
#Step 60 is the inplace optimization stage.
compile.optdb.register('dimshuffle_as_view',
TopoOptimizer(dimshuffle_as_view,
failure_callback=TopoOptimizer.warn_inplace), 60,
TopoOptimizer(
dimshuffle_as_view,
failure_callback=TopoOptimizer.warn_inplace),
60,
'fast_run', 'inplace')
register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift)
......@@ -771,7 +786,8 @@ class ShapeFeature(object):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one
else:
return Shape_i(i).make_node(r).outputs[0]
# Do not call make_node for test_value
return Shape_i(i)(r)
def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r"""
......@@ -970,9 +986,9 @@ class ShapeFeature(object):
# shape var -> graph v
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, reason='on_attach')
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.shape_of:
# this is a revert, not really an import
for r in node.outputs + node.inputs:
......@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node):
sl_ins = Subtensor.collapse(
merged_slices,
lambda x: isinstance(x, T.Variable))
out = subtens.make_node(x, *sl_ins).outputs[0]
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
return [out]
......@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else:
tmp_s_input.append(scalar.Scalar(
ii.dtype).make_variable())
tmp = scalar.Scalar(ii.dtype).make_variable()
try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError:
pass
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input)
......@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s = s_inputs[inputs.index(i)]
else:
s = scalar.Scalar(i.dtype).make_variable()
try:
v = gof.op.get_test_value(i)
if v.size > 0:
s.tag.test_value = gof.op.get_test_value(i).flatten()[0]
except AttributeError:
pass
inputs.append(i)
s_inputs.append(s)
s_g.append(s)
......@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""")
C = scalar.Composite(s_inputs, [s_new_out])
#create the new node.
n = OP(C).make_node(*inputs)
#Do not call make_node to have test_value
n = OP(C)(*inputs).owner
assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype
......@@ -4728,9 +4757,11 @@ if config.tensor.local_elemwise_fusion:
_logger.debug("enabling optimization fusion elemwise in fast_run")
compile.optdb.register('elemwise_fusion',
FusionOptimizer(local_elemwise_fusion), 71.00,
'fast_run', 'fusion', 'local_elemwise_fusion')
'fast_run', 'fusion', 'local_elemwise_fusion',
'FusionOptimizer')
else:
_logger.debug("not enabling optimization fusion elemwise in fast_run")
compile.optdb.register('elemwise_fusion',
FusionOptimizer(local_elemwise_fusion), 71.00,
'fusion', 'local_elemwise_fusion')
'fusion', 'local_elemwise_fusion',
'FusionOptimizer')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论