提交 a351e3db authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1482 from nouiz/rnade

Scan crash fix
...@@ -482,6 +482,14 @@ import theano and print the config variable, as in: ...@@ -482,6 +482,14 @@ import theano and print the config variable, as in:
This flag's value cannot be modified during the program execution. This flag's value cannot be modified during the program execution.
.. attribute:: optimizer_verbose
Bool value: either True or False
Default: False
When True, we print on the stdout the optimization applied.
.. attribute:: nocleanup .. attribute:: nocleanup
Bool value: either True or False Bool value: either True or False
...@@ -630,6 +638,12 @@ import theano and print the config variable, as in: ...@@ -630,6 +638,12 @@ import theano and print the config variable, as in:
this Op this Op
- ``'raise'`` will raise an Exception - ``'raise'`` will raise an Exception
.. attribute:: config.compute_test_value_opt
As ``compute_test_value``, but it is the value used during Theano
optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization.
.. attribute:: config.exception_verbosity .. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``. String Value: ``'low'``, ``'high'``.
......
...@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object): ...@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object):
self.reasons = {} self.reasons = {}
self.replaced_by = {} self.replaced_by = {}
self.event_list = [] self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
assert fgraph is self.fgraph assert fgraph is self.fgraph
self.fgraph = None self.fgraph = None
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node)) self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason))
#print 'PRUNING NODE', node, id(node) #print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
self.active_nodes.remove(node) self.active_nodes.remove(node)
self.inactive_nodes.add(node) self.inactive_nodes.add(node)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('import', node)) self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason))
#print 'NEW NODE', node, id(node) #print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
...@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph # optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value compute_test_value_orig = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = "off" theano.config.compute_test_value = theano.config.compute_test_value_opt
optimizer(fgraph) optimizer(fgraph)
theano.compile.function_module.insert_deepcopy(fgraph, inputs, theano.compile.function_module.insert_deepcopy(fgraph, inputs,
......
...@@ -1018,7 +1018,7 @@ class FunctionMaker(object): ...@@ -1018,7 +1018,7 @@ class FunctionMaker(object):
compute_test_value_orig = theano.config.compute_test_value compute_test_value_orig = theano.config.compute_test_value
add_stack_trace_on_call = gof.Op.add_stack_trace_on_call add_stack_trace_on_call = gof.Op.add_stack_trace_on_call
try: try:
theano.config.compute_test_value = "off" theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False gof.Op.add_stack_trace_on_call = False
start_optimizer = time.time() start_optimizer = time.time()
optimizer_profile = optimizer(fgraph) optimizer_profile = optimizer(fgraph)
......
...@@ -157,6 +157,11 @@ AddConfigVar('optimizer', ...@@ -157,6 +157,11 @@ AddConfigVar('optimizer',
EnumStr('fast_run', 'merge', 'fast_compile', 'None'), EnumStr('fast_run', 'merge', 'fast_compile', 'None'),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_verbose',
"If True, we print all optimization being applied",
BoolParam(False),
in_c_key=False)
AddConfigVar('on_opt_error', AddConfigVar('on_opt_error',
("What to do when an optimization crashes: warn and skip it, raise " ("What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."), "the exception, or fall into the pdb debugger."),
...@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value', ...@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value',
"Constants, SharedVariables and the tag 'test_value' as inputs " "Constants, SharedVariables and the tag 'test_value' as inputs "
"to the function. This helps the user track down problems in the " "to the function. This helps the user track down problems in the "
"graph before it gets optimized."), "graph before it gets optimized."),
EnumStr('off', 'ignore', 'warn', 'raise'), EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False) in_c_key=False)
AddConfigVar('compute_test_value_opt',
("For debugging Theano optimization only."
" Same as compute_test_value, but is used"
" during Theano optimization"),
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False)
"""Note to developers: """Note to developers:
Generally your exceptions should use an apply node's __str__ Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity method when exception_verbosity == 'low'. When exception_verbosity
......
...@@ -380,7 +380,7 @@ if 0: ...@@ -380,7 +380,7 @@ if 0:
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import") #if app in self.debug_all_apps: raise ProtocolError("double import")
...@@ -410,7 +410,7 @@ if 0: ...@@ -410,7 +410,7 @@ if 0:
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import") #if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app) #self.debug_all_apps.remove(app)
...@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import") if app in self.debug_all_apps: raise ProtocolError("double import")
...@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import") if app not in self.debug_all_apps: raise ProtocolError("prune without import")
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
......
差异被折叠。
...@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>" ...@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import logging import logging
import sys
import warnings import warnings
import theano import theano
...@@ -408,6 +409,9 @@ class PureOp(object): ...@@ -408,6 +409,9 @@ class PureOp(object):
elif config.compute_test_value == 'ignore': elif config.compute_test_value == 'ignore':
# silently skip test # silently skip test
run_perform = False run_perform = False
elif config.compute_test_value == 'pdb':
import pdb
pdb.post_mortem(sys.exc_info()[2])
else: else:
raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value) raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value)
...@@ -638,8 +642,11 @@ def get_test_value(v): ...@@ -638,8 +642,11 @@ def get_test_value(v):
For a Shared variable, it is the internal value. For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value. For another Variable, it is the content of v.tag.test_value.
""" """
v_tensor = theano.tensor.as_tensor_variable(v) if not isinstance(v, graph.Variable):
return PureOp._get_test_value(v_tensor) v_var = theano.tensor.as_tensor_variable(v)
else:
v_var = v
return PureOp._get_test_value(v_var)
def missing_test_message(msg): def missing_test_message(msg):
......
...@@ -421,7 +421,7 @@ class MergeFeature(object): ...@@ -421,7 +421,7 @@ class MergeFeature(object):
self.blacklist = [] self.blacklist = []
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct # If inputs to node change, it is not guaranteed that it is distinct
...@@ -433,14 +433,14 @@ class MergeFeature(object): ...@@ -433,14 +433,14 @@ class MergeFeature(object):
if isinstance(new_r, graph.Constant): if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r) self.process_constant(fgraph, new_r)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant): if isinstance(c, graph.Constant):
self.process_constant(fgraph, c) self.process_constant(fgraph, c)
self.process_node(fgraph, node) self.process_node(fgraph, node)
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1): if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
...@@ -548,7 +548,7 @@ class MergeOptimizer(Optimizer): ...@@ -548,7 +548,7 @@ class MergeOptimizer(Optimizer):
except InconsistencyError: except InconsistencyError:
success = False success = False
fgraph.merge_feature.blacklist.append( fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner)) (pairs[0][0].owner, pairs[0][1].owner))
if success: if success:
break break
...@@ -1027,7 +1027,7 @@ class PatternSub(LocalOptimizer): ...@@ -1027,7 +1027,7 @@ class PatternSub(LocalOptimizer):
else: else:
return pattern.clone() return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True, u = match(self.in_pattern, node.out, unify.Unification(), True,
self.pdb) self.pdb)
if u: if u:
p = self.out_pattern p = self.out_pattern
new = build(p, u) new = build(p, u)
...@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer): ...@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer):
class Updater: class Updater:
if importer is not None: if importer is not None:
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
importer(node) importer(node)
if pruner is not None: if pruner is not None:
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
pruner(node) pruner(node)
if chin is not None: if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
...@@ -1357,7 +1357,7 @@ class ChangeTracker: ...@@ -1357,7 +1357,7 @@ class ChangeTracker:
def __init__(self): def __init__(self):
self.changed = False self.changed = False
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
self.changed = True self.changed = True
def on_change_input(self, fgraph, node, i, r, new_r): def on_change_input(self, fgraph, node, i, r, new_r):
......
import sys import sys
import time import time
from theano import config
from theano.gof.python25 import partial from theano.gof.python25 import partial
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.gof import graph from theano.gof import graph
class AlreadyThere(Exception): class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph """Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical attempting to attach the feature already has a functionally identical
...@@ -57,7 +56,7 @@ class Feature(object): ...@@ -57,7 +56,7 @@ class Feature(object):
functionality that it installed into the function_graph. functionality that it installed into the function_graph.
""" """
def on_import(self, function_graph, node): def on_import(self, function_graph, node, reason):
""" """
Called whenever a node is imported into function_graph, which is Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph. just before the node is actually connected to the graph.
...@@ -66,7 +65,7 @@ class Feature(object): ...@@ -66,7 +65,7 @@ class Feature(object):
you should do this by implementing on_attach. you should do this by implementing on_attach.
""" """
def on_prune(self, function_graph, node): def on_prune(self, function_graph, node, reason):
""" """
Called whenever a node is pruned (removed) from the function_graph, Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph. after it is disconnected from the graph.
...@@ -98,11 +97,11 @@ class Bookkeeper(Feature): ...@@ -98,11 +97,11 @@ class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node) self.on_prune(fgraph, node, 'Bookkeeper.detach')
class History(Feature): class History(Feature):
...@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator): ...@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator):
def replace_validate(self, fgraph, r, new_r, reason=None): def replace_validate(self, fgraph, r, new_r, reason=None):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason) self.replace_all_validate(fgraph, [(r, new_r)], reason=reason)
def replace_all_validate(self, fgraph, replacements, reason=None): def replace_all_validate(self, fgraph, replacements,
reason=None, verbose=None):
chk = fgraph.checkpoint() chk = fgraph.checkpoint()
if verbose is None:
verbose = config.optimizer_verbose
for r, new_r in replacements: for r, new_r in replacements:
try: try:
fgraph.replace(r, new_r, reason=reason) fgraph.replace(r, new_r, reason=reason, verbose=False)
except Exception, e: except Exception, e:
if ('The type of the replacement must be the same' not in if ('The type of the replacement must be the same' not in
str(e) and 'does not belong to this FunctionGraph' not in str(e)): str(e) and 'does not belong to this FunctionGraph' not in str(e)):
...@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator): ...@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator):
except Exception, e: except Exception, e:
fgraph.revert(chk) fgraph.revert(chk)
raise raise
if verbose:
print reason, r, new_r
return chk return chk
def replace_all_validate_remove(self, fgraph, replacements, def replace_all_validate_remove(self, fgraph, replacements,
...@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper):
del fgraph.get_nodes del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph) Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
try: try:
self.setdefault(node.op, []).append(node) self.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
...@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper):
print >> sys.stderr, 'OFFENDING node not hashable' print >> sys.stderr, 'OFFENDING node not hashable'
raise e raise e
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
try: try:
nodes = self[node.op] nodes = self[node.op]
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
...@@ -312,13 +316,13 @@ class PrintListener(Feature): ...@@ -312,13 +316,13 @@ class PrintListener(Feature):
if self.active: if self.active:
print "-- detaching from: ", fgraph print "-- detaching from: ", fgraph
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if self.active: if self.active:
print "-- importing: %s" % node print "-- importing: %s, reason: %s" % (node, reason)
def on_prune(self, fgraph, node): def on_prune(self, fgraph, node, reason):
if self.active: if self.active:
print "-- pruning: %s" % node print "-- pruning: %s, reason: %s" % (node, reason)
def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active: if self.active:
......
...@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp):
axis = inputs[0] axis = inputs[0]
n_cndas = len(inputs[1:]) n_cndas = len(inputs[1:])
input_1 = inputs[1] input_1 = inputs[1]
axis = inputs[0]
fail = sub['fail'] fail = sub['fail']
out = out_[0] out = out_[0]
......
...@@ -137,9 +137,9 @@ class HintsFeature(object): ...@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple) # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.hints = {} self.hints = {}
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, "on_attach")
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.hints: if node.outputs[0] in self.hints:
# this is a revert, not really an import # this is a revert, not really an import
for r in node.outputs + node.inputs: for r in node.outputs + node.inputs:
......
...@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph # shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately # It will call infer_shape and set_shape appropriately
dummy_fgraph = None dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner) shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = [] ret = []
for o in outs: for o in outs:
......
...@@ -183,6 +183,24 @@ class Scalar(Type): ...@@ -183,6 +183,24 @@ class Scalar(Type):
def dtype_specs(self): def dtype_specs(self):
try: try:
# To help debug dtype/typenum problem, here is code to get
# the list of numpy typenum. This list change between 32
# and 64 bit platform and probably also also between
# Windows and Linux.
# NOTE: equivalent type on a platform can have different typenum.
# This is the source of all dtype/typenum problem found up to
# now, as Theano always expect the exact typenum that
# correspond to our supported dtype.
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'float', 'double',
'int', 'uint']:
print dtype, np.zeros(1, dtype=dtype).dtype.num
"""
return { # dtype: (py_type, c_type, cls_name) return { # dtype: (py_type, c_type, cls_name)
'float32': (numpy.float32, 'npy_float32', 'Float32'), 'float32': (numpy.float32, 'npy_float32', 'Float32'),
'float64': (numpy.float64, 'npy_float64', 'Float64'), 'float64': (numpy.float64, 'npy_float64', 'Float64'),
......
...@@ -101,7 +101,7 @@ def scan(fn, ...@@ -101,7 +101,7 @@ def scan(fn,
The order of the sequences is the same as the one in the list The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the same `sequences` given to scan. The order of the outputs is the same
as the order of ``output_info``. For any sequence or output the as the order of ``outputs_info``. For any sequence or output the
order of the time slices is the same as the one in which they have order of the time slices is the same as the one in which they have
been given as taps. For example if one writes the following : been given as taps. For example if one writes the following :
...@@ -262,7 +262,7 @@ def scan(fn, ...@@ -262,7 +262,7 @@ def scan(fn,
outputs will have *0 rows*. If the value is negative, ``scan`` outputs will have *0 rows*. If the value is negative, ``scan``
will run backwards in time. If the ``go_backwards`` flag is already will run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n stpes is not provided, ``scan`` will figure in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences. out the amount of steps it should run given its input sequences.
...@@ -817,7 +817,7 @@ def scan(fn, ...@@ -817,7 +817,7 @@ def scan(fn,
if as_while: if as_while:
tmp_dummy_f_outs -= 1 tmp_dummy_f_outs -= 1
if not (tmp_dummy_f_outs == n_outs or outs_info == []): if not (tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for ' raise ValueError('Please provide None as outputs_info for '
'any output that does not feed back into ' 'any output that does not feed back into '
'scan (i.e. it behaves like a map) ') 'scan (i.e. it behaves like a map) ')
......
...@@ -1581,8 +1581,30 @@ class Scan(PureOp): ...@@ -1581,8 +1581,30 @@ class Scan(PureOp):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1]) outer_inp_seqs.append(x[::-1])
outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)] if hasattr(inputs[0].tag, 'test_value'):
outer_inp_seqs += [x[::-1] for x in self.outer_sitsot_outs(outs)] # Here we tests that the new scan input sequence all have
# the same shape[0]. This is a properties that the scan()
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs)):
mintap = numpy.min(taps)
if hasattr(x[::-1][:mintap], 'test_value'):
assert (x[::-1][:mintap].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'):
assert (x[::-1][:-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'):
assert (x[::-1].tag.test_value.shape[0] ==
inputs[0].tag.test_value)
outer_inp_seqs += [x[::-1][:numpy.min(taps)]
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
inner_inp_seqs = self.inner_seqs(self_inputs) inner_inp_seqs = self.inner_seqs(self_inputs)
......
...@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph # shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately # It will call infer_shape and set_shape appropriately
dummy_fgraph = None dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner) shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = [] ret = []
for o in outs: for o in outs:
......
...@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase): ...@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase):
go_backwards=False) go_backwards=False)
gX, gY = tensor.grad(values[1].sum(), [x, y]) gX, gY = tensor.grad(values[1].sum(), [x, y])
f = theano.function([c, x, y], [gX, gY], f = theano.function([c, x, y], [gX, gY],
allow_input_downcast=True) allow_input_downcast=True)
# Check for runtime errors # Check for runtime errors
f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5)) f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))
...@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase): ...@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase):
x0 = theano.tensor.vector('x0') x0 = theano.tensor.vector('x0')
y0 = theano.tensor.vector('y0') y0 = theano.tensor.vector('y0')
W_in1.tag.test_value = vW_in1
u1.tag.test_value = v_u1
u2.tag.test_value = v_u2
x0.tag.test_value = v_x0
y0.tag.test_value = v_y0
def f_rnn_cmpl(u1_t, def f_rnn_cmpl(u1_t,
u2_tm1, u2_tm1,
u2_t, u2_t,
...@@ -1553,33 +1559,46 @@ class T_Scan(unittest.TestCase): ...@@ -1553,33 +1559,46 @@ class T_Scan(unittest.TestCase):
y_tm1, y_tm1,
y_tm3, y_tm3,
W_in1): W_in1):
return [theano.dot(u1_t, W_in1) + \ return [theano.dot(u1_t, W_in1) +
(u2_t + u2_tm1 * u2_tp1) * W_in2 + \ (u2_t + u2_tm1 * u2_tp1) * W_in2 +
theano.dot(x_tm1, W), theano.dot(x_tm1, W),
(y_tm1 + y_tm3) * theano.dot(x_tm1, W_out), (y_tm1 + y_tm3) * theano.dot(x_tm1, W_out),
theano.dot(u1_t, W_in1)] theano.dot(u1_t, W_in1)]
cost, updates = scan_project_sum(
f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1]
params = [u1, u2, x0, y0, W_in1]
gparams = theano.tensor.grad(cost, params)
grad_fn = theano.function([u1, u2, x0, y0, W_in1],
gparams,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
cost_fn = theano.function([u1, u2, x0, y0, W_in1], # We change the compute_test_value[_opt] flag to run the
cost, # assert in Scan.grad() of the new scan input sequence related
updates=updates, # to outer_mitsot_outs, outer_sitsot_outs and
no_default_updates=True, # outer_nitsot_outs. This allow to test an old Scan bug.
allow_input_downcast=True) old1 = theano.config.compute_test_value
old2 = theano.config.compute_test_value_opt
theano.config.compute_test_value = 'raise'
theano.config.compute_test_value_opt = 'raise'
try:
cost, updates = scan_project_sum(
f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1]
params = [u1, u2, x0, y0, W_in1]
gparams = theano.tensor.grad(cost, params)
grad_fn = theano.function([u1, u2, x0, y0, W_in1],
gparams,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
cost_fn = theano.function([u1, u2, x0, y0, W_in1],
cost,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
finally:
theano.config.compute_test_value = old1
theano.config.compute_test_value_opt = old2
num_grad = multiple_outputs_numeric_grad(cost_fn, num_grad = multiple_outputs_numeric_grad(cost_fn,
[v_u1, [v_u1,
......
...@@ -2543,7 +2543,7 @@ class Alloc(gof.Op): ...@@ -2543,7 +2543,7 @@ class Alloc(gof.Op):
#change. #change.
return [gx] + [DisconnectedType()() for i in inputs[1:]] return [gx] + [DisconnectedType()() for i in inputs[1:]]
def __call__(self, val, *shapes): def __call__(self, val, *shapes, **kwargs):
""" """
If the alloc would be useless, this function returns val. If the alloc would be useless, this function returns val.
...@@ -2554,7 +2554,7 @@ class Alloc(gof.Op): ...@@ -2554,7 +2554,7 @@ class Alloc(gof.Op):
If you always want an Alloc node, call make_node. If you always want an Alloc node, call make_node.
""" """
ret = super(Alloc, self).__call__(val, *shapes) ret = super(Alloc, self).__call__(val, *shapes, **kwargs)
try: try:
# It makes optimization difficult when useless allocs are thrown # It makes optimization difficult when useless allocs are thrown
# into the graph at every stage of optimization. This little logic # into the graph at every stage of optimization. This little logic
......
...@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error', ...@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error',
def out2in(*local_opts): def out2in(*local_opts):
"""WRITEME """ """WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='out_to_in', order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace) failure_callback=TopoOptimizer.warn_inplace)
def in2out(*local_opts, **kwargs): def in2out(*local_opts, **kwargs):
"""WRITEME """ """WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='in_to_out', order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace, failure_callback=TopoOptimizer.warn_inplace,
**kwargs) **kwargs)
...@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node): ...@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node):
input = node.inputs[0] input = node.inputs[0]
inode = input.owner inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1): if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable, # Don't use make_node to have tag.test_value set.
op.new_order, ret = inode.op(*[DimShuffle(input.type.broadcastable,
op.inplace)(input) for input in op.new_order,
inode.inputs]).outputs op.inplace)(input) for input in
inode.inputs], **dict(return_list=True))
return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order] op.new_order]
...@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node): ...@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node):
iinput.type.ndim): iinput.type.ndim):
return [iinput] return [iinput]
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, ret = DimShuffle(iinput.type.broadcastable, new_order,
inplace).make_node(iinput).outputs inplace)(iinput, **dict(return_list=True))
return ret
@register_canonicalize @register_canonicalize
...@@ -437,8 +450,10 @@ def dimshuffle_as_view(node): ...@@ -437,8 +450,10 @@ def dimshuffle_as_view(node):
#Step 60 is the inplace optimization stage. #Step 60 is the inplace optimization stage.
compile.optdb.register('dimshuffle_as_view', compile.optdb.register('dimshuffle_as_view',
TopoOptimizer(dimshuffle_as_view, TopoOptimizer(
failure_callback=TopoOptimizer.warn_inplace), 60, dimshuffle_as_view,
failure_callback=TopoOptimizer.warn_inplace),
60,
'fast_run', 'inplace') 'fast_run', 'inplace')
register_canonicalize(local_dimshuffle_lift) register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift) register_specialize(local_dimshuffle_lift)
...@@ -771,7 +786,8 @@ class ShapeFeature(object): ...@@ -771,7 +786,8 @@ class ShapeFeature(object):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]: if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one return self.lscalar_one
else: else:
return Shape_i(i).make_node(r).outputs[0] # Do not call make_node for test_value
return Shape_i(i)(r)
def shape_tuple(self, r): def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r""" """Return a tuple of symbolic shape vars for tensor variable r"""
...@@ -970,9 +986,9 @@ class ShapeFeature(object): ...@@ -970,9 +986,9 @@ class ShapeFeature(object):
# shape var -> graph v # shape var -> graph v
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node) self.on_import(fgraph, node, reason='on_attach')
def on_import(self, fgraph, node): def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.shape_of: if node.outputs[0] in self.shape_of:
# this is a revert, not really an import # this is a revert, not really an import
for r in node.outputs + node.inputs: for r in node.outputs + node.inputs:
...@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node): ...@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node):
sl_ins = Subtensor.collapse( sl_ins = Subtensor.collapse(
merged_slices, merged_slices,
lambda x: isinstance(x, T.Variable)) lambda x: isinstance(x, T.Variable))
out = subtens.make_node(x, *sl_ins).outputs[0] # Do not call make_node for test_value
out = subtens(x, *sl_ins)
return [out] return [out]
...@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input: elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else: else:
tmp_s_input.append(scalar.Scalar( tmp = scalar.Scalar(ii.dtype).make_variable()
ii.dtype).make_variable()) try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError:
pass
tmp_s_input.append(tmp)
tmp_input.append(ii) tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1]) tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input) s_op = i.owner.op.scalar_op(*tmp_s_input)
...@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s = s_inputs[inputs.index(i)] s = s_inputs[inputs.index(i)]
else: else:
s = scalar.Scalar(i.dtype).make_variable() s = scalar.Scalar(i.dtype).make_variable()
try:
v = gof.op.get_test_value(i)
if v.size > 0:
s.tag.test_value = gof.op.get_test_value(i).flatten()[0]
except AttributeError:
pass
inputs.append(i) inputs.append(i)
s_inputs.append(s) s_inputs.append(s)
s_g.append(s) s_g.append(s)
...@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""") ...@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""")
C = scalar.Composite(s_inputs, [s_new_out]) C = scalar.Composite(s_inputs, [s_new_out])
#create the new node. #create the new node.
n = OP(C).make_node(*inputs) #Do not call make_node to have test_value
n = OP(C)(*inputs).owner
assert len(n.outputs) == 1 assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype assert node.outputs[0].dtype == n.outputs[0].dtype
...@@ -4728,9 +4757,11 @@ if config.tensor.local_elemwise_fusion: ...@@ -4728,9 +4757,11 @@ if config.tensor.local_elemwise_fusion:
_logger.debug("enabling optimization fusion elemwise in fast_run") _logger.debug("enabling optimization fusion elemwise in fast_run")
compile.optdb.register('elemwise_fusion', compile.optdb.register('elemwise_fusion',
FusionOptimizer(local_elemwise_fusion), 71.00, FusionOptimizer(local_elemwise_fusion), 71.00,
'fast_run', 'fusion', 'local_elemwise_fusion') 'fast_run', 'fusion', 'local_elemwise_fusion',
'FusionOptimizer')
else: else:
_logger.debug("not enabling optimization fusion elemwise in fast_run") _logger.debug("not enabling optimization fusion elemwise in fast_run")
compile.optdb.register('elemwise_fusion', compile.optdb.register('elemwise_fusion',
FusionOptimizer(local_elemwise_fusion), 71.00, FusionOptimizer(local_elemwise_fusion), 71.00,
'fusion', 'local_elemwise_fusion') 'fusion', 'local_elemwise_fusion',
'FusionOptimizer')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论