提交 a351e3db authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1482 from nouiz/rnade

Scan crash fix
......@@ -482,6 +482,14 @@ import theano and print the config variable, as in:
This flag's value cannot be modified during the program execution.
.. attribute:: optimizer_verbose
Bool value: either True or False
Default: False
When True, Theano prints each optimization it applies to stdout.
.. attribute:: nocleanup
Bool value: either True or False
......@@ -630,6 +638,12 @@ import theano and print the config variable, as in:
this Op
- ``'raise'`` will raise an Exception
.. attribute:: config.compute_test_value_opt
Same as ``compute_test_value``, but this is the value used during the
Theano optimization phase. Theano users do not need to use this. It
helps debug shape errors in Theano optimizations.
.. attribute:: config.exception_verbosity
String Value: ``'low'``, ``'high'``.
......
......@@ -1428,21 +1428,25 @@ class _VariableEquivalenceTracker(object):
self.reasons = {}
self.replaced_by = {}
self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
assert fgraph is self.fgraph
self.fgraph = None
def on_prune(self, fgraph, node):
self.event_list.append(_FunctionGraphEvent('prune', node))
def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason))
#print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes
assert node not in self.inactive_nodes
self.active_nodes.remove(node)
self.inactive_nodes.add(node)
def on_import(self, fgraph, node):
self.event_list.append(_FunctionGraphEvent('import', node))
def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason))
#print 'NEW NODE', node, id(node)
assert node not in self.active_nodes
......@@ -2114,7 +2118,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value
try:
theano.config.compute_test_value = "off"
theano.config.compute_test_value = theano.config.compute_test_value_opt
optimizer(fgraph)
theano.compile.function_module.insert_deepcopy(fgraph, inputs,
......
......@@ -1018,7 +1018,7 @@ class FunctionMaker(object):
compute_test_value_orig = theano.config.compute_test_value
add_stack_trace_on_call = gof.Op.add_stack_trace_on_call
try:
theano.config.compute_test_value = "off"
theano.config.compute_test_value = theano.config.compute_test_value_opt
gof.Op.add_stack_trace_on_call = False
start_optimizer = time.time()
optimizer_profile = optimizer(fgraph)
......
......@@ -157,6 +157,11 @@ AddConfigVar('optimizer',
EnumStr('fast_run', 'merge', 'fast_compile', 'None'),
in_c_key=False)
AddConfigVar('optimizer_verbose',
"If True, we print all optimization being applied",
BoolParam(False),
in_c_key=False)
AddConfigVar('on_opt_error',
("What to do when an optimization crashes: warn and skip it, raise "
"the exception, or fall into the pdb debugger."),
......@@ -379,10 +384,17 @@ AddConfigVar('compute_test_value',
"Constants, SharedVariables and the tag 'test_value' as inputs "
"to the function. This helps the user track down problems in the "
"graph before it gets optimized."),
EnumStr('off', 'ignore', 'warn', 'raise'),
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False)
AddConfigVar('compute_test_value_opt',
("For debugging Theano optimization only."
" Same as compute_test_value, but is used"
" during Theano optimization"),
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False)
"""Note to developers:
Generally your exceptions should use an apply node's __str__
method when exception_verbosity == 'low'. When exception_verbosity
......
......@@ -380,7 +380,7 @@ if 0:
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_import(self, fgraph, app):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
......@@ -410,7 +410,7 @@ if 0:
self.stale_droot = True
def on_prune(self, fgraph, app):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app)
......@@ -765,7 +765,7 @@ class DestroyHandler(toolbox.Bookkeeper):
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_import(self, fgraph, app):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import")
......@@ -795,7 +795,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self.stale_droot = True
def on_prune(self, fgraph, app):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import")
self.debug_all_apps.remove(app)
......
差异被折叠。
......@@ -13,6 +13,7 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
import logging
import sys
import warnings
import theano
......@@ -408,6 +409,9 @@ class PureOp(object):
elif config.compute_test_value == 'ignore':
# silently skip test
run_perform = False
elif config.compute_test_value == 'pdb':
import pdb
pdb.post_mortem(sys.exc_info()[2])
else:
raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value)
......@@ -638,8 +642,11 @@ def get_test_value(v):
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
v_tensor = theano.tensor.as_tensor_variable(v)
return PureOp._get_test_value(v_tensor)
if not isinstance(v, graph.Variable):
v_var = theano.tensor.as_tensor_variable(v)
else:
v_var = v
return PureOp._get_test_value(v_var)
def missing_test_message(msg):
......
......@@ -421,7 +421,7 @@ class MergeFeature(object):
self.blacklist = []
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct
......@@ -433,14 +433,14 @@ class MergeFeature(object):
if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r)
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
for c in node.inputs:
if isinstance(c, graph.Constant):
self.process_constant(fgraph, c)
self.process_node(fgraph, node)
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node)
for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
......@@ -548,7 +548,7 @@ class MergeOptimizer(Optimizer):
except InconsistencyError:
success = False
fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
(pairs[0][0].owner, pairs[0][1].owner))
if success:
break
......@@ -1027,7 +1027,7 @@ class PatternSub(LocalOptimizer):
else:
return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True,
self.pdb)
self.pdb)
if u:
p = self.out_pattern
new = build(p, u)
......@@ -1165,10 +1165,10 @@ class NavigatorOptimizer(Optimizer):
class Updater:
if importer is not None:
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
importer(node)
if pruner is not None:
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
pruner(node)
if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r):
......@@ -1357,7 +1357,7 @@ class ChangeTracker:
def __init__(self):
self.changed = False
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
self.changed = True
def on_change_input(self, fgraph, node, i, r, new_r):
......
import sys
import time
from theano import config
from theano.gof.python25 import partial
from theano.gof.python25 import OrderedDict
from theano.gof import graph
class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
......@@ -57,7 +56,7 @@ class Feature(object):
functionality that it installed into the function_graph.
"""
def on_import(self, function_graph, node):
def on_import(self, function_graph, node, reason):
"""
Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph.
......@@ -66,7 +65,7 @@ class Feature(object):
you should do this by implementing on_attach.
"""
def on_prune(self, function_graph, node):
def on_prune(self, function_graph, node, reason):
"""
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
......@@ -98,11 +97,11 @@ class Bookkeeper(Feature):
def on_attach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node)
self.on_prune(fgraph, node, 'Bookkeeper.detach')
class History(Feature):
......@@ -199,11 +198,14 @@ class ReplaceValidate(History, Validator):
def replace_validate(self, fgraph, r, new_r, reason=None):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason)
def replace_all_validate(self, fgraph, replacements, reason=None):
def replace_all_validate(self, fgraph, replacements,
reason=None, verbose=None):
chk = fgraph.checkpoint()
if verbose is None:
verbose = config.optimizer_verbose
for r, new_r in replacements:
try:
fgraph.replace(r, new_r, reason=reason)
fgraph.replace(r, new_r, reason=reason, verbose=False)
except Exception, e:
if ('The type of the replacement must be the same' not in
str(e) and 'does not belong to this FunctionGraph' not in str(e)):
......@@ -219,6 +221,8 @@ class ReplaceValidate(History, Validator):
except Exception, e:
fgraph.revert(chk)
raise
if verbose:
print reason, r, new_r
return chk
def replace_all_validate_remove(self, fgraph, replacements,
......@@ -267,7 +271,7 @@ class NodeFinder(dict, Bookkeeper):
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
try:
self.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
......@@ -280,7 +284,7 @@ class NodeFinder(dict, Bookkeeper):
print >> sys.stderr, 'OFFENDING node not hashable'
raise e
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
try:
nodes = self[node.op]
except TypeError: # node.op is unhashable
......@@ -312,13 +316,13 @@ class PrintListener(Feature):
if self.active:
print "-- detaching from: ", fgraph
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if self.active:
print "-- importing: %s" % node
print "-- importing: %s, reason: %s" % (node, reason)
def on_prune(self, fgraph, node):
def on_prune(self, fgraph, node, reason):
if self.active:
print "-- pruning: %s" % node
print "-- pruning: %s, reason: %s" % (node, reason)
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
......
......@@ -2953,7 +2953,6 @@ class GpuJoin(tensor.Join, GpuOp):
axis = inputs[0]
n_cndas = len(inputs[1:])
input_1 = inputs[1]
axis = inputs[0]
fail = sub['fail']
out = out_[0]
......
......@@ -137,9 +137,9 @@ class HintsFeature(object):
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.hints = {}
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, "on_attach")
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.hints:
# this is a revert, not really an import
for r in node.outputs + node.inputs:
......
......@@ -338,7 +338,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner)
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = []
for o in outs:
......
......@@ -183,6 +183,24 @@ class Scalar(Type):
def dtype_specs(self):
try:
# To help debug dtype/typenum problems, here is code to get
# the list of numpy typenums. This list changes between 32-bit
# and 64-bit platforms, and probably also between
# Windows and Linux.
# NOTE: equivalent types on a platform can have different typenums.
# This is the source of all dtype/typenum problems found up to
# now, as Theano always expects the exact typenum that
# corresponds to our supported dtypes.
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'float', 'double',
'int', 'uint']:
print dtype, np.zeros(1, dtype=dtype).dtype.num
"""
return { # dtype: (py_type, c_type, cls_name)
'float32': (numpy.float32, 'npy_float32', 'Float32'),
'float64': (numpy.float64, 'npy_float64', 'Float64'),
......
......@@ -101,7 +101,7 @@ def scan(fn,
The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the same
as the order of ``output_info``. For any sequence or output the
as the order of ``outputs_info``. For any sequence or output the
order of the time slices is the same as the one in which they have
been given as taps. For example if one writes the following :
......@@ -262,7 +262,7 @@ def scan(fn,
outputs will have *0 rows*. If the value is negative, ``scan``
will run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n stpes is not provided, ``scan`` will figure
in time. If n_steps is not provided, ``scan`` will figure
out the amount of steps it should run given its input sequences.
......@@ -817,7 +817,7 @@ def scan(fn,
if as_while:
tmp_dummy_f_outs -= 1
if not (tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for '
raise ValueError('Please provide None as outputs_info for '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) ')
......
......@@ -1581,8 +1581,30 @@ class Scan(PureOp):
if not isinstance(x.type, DisconnectedType):
outer_inp_seqs.append(x[::-1])
outer_inp_seqs += [x[::-1] for x in self.outer_mitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_sitsot_outs(outs)]
if hasattr(inputs[0].tag, 'test_value'):
    # Debug-only sanity check, active when inputs[0] carries a test
    # value: every new scan input sequence built from the outer
    # outputs must have a shape[0] equal to that test value.  This
    # is a property that the scan() function adds and that we want
    # to keep for every Scan op.
    # T_Scan.test_grad_multiple_outs_taps is used to exercise it.
    # NOTE(review): inputs[0] is presumably the number of steps --
    # confirm against the Scan input layout.
    for taps, x in zip(self.mitsot_taps(),
                       self.outer_mitsot_outs(outs)):
        mintap = numpy.min(taps)
        # Look the test value up on ``tag``, as the sitsot and nitsot
        # checks below do; the previous guard inspected the variable
        # itself, which is not where the assert reads the value.
        if hasattr(x[::-1][:mintap].tag, 'test_value'):
            assert (x[::-1][:mintap].tag.test_value.shape[0] ==
                    inputs[0].tag.test_value)
    for x in self.outer_sitsot_outs(outs):
        if hasattr(x[::-1][:-1].tag, 'test_value'):
            assert (x[::-1][:-1].tag.test_value.shape[0] ==
                    inputs[0].tag.test_value)
    for x in self.outer_nitsot_outs(outs):
        if hasattr(x[::-1].tag, 'test_value'):
            assert (x[::-1].tag.test_value.shape[0] ==
                    inputs[0].tag.test_value)
outer_inp_seqs += [x[::-1][:numpy.min(taps)]
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
inner_inp_seqs = self.inner_seqs(self_inputs)
......
......@@ -500,7 +500,7 @@ def infer_shape(outs, inputs, input_shapes):
# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner)
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
ret = []
for o in outs:
......
......@@ -1141,7 +1141,7 @@ class T_Scan(unittest.TestCase):
go_backwards=False)
gX, gY = tensor.grad(values[1].sum(), [x, y])
f = theano.function([c, x, y], [gX, gY],
allow_input_downcast=True)
allow_input_downcast=True)
# Check for runtime errors
f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))
......@@ -1545,6 +1545,12 @@ class T_Scan(unittest.TestCase):
x0 = theano.tensor.vector('x0')
y0 = theano.tensor.vector('y0')
W_in1.tag.test_value = vW_in1
u1.tag.test_value = v_u1
u2.tag.test_value = v_u2
x0.tag.test_value = v_x0
y0.tag.test_value = v_y0
def f_rnn_cmpl(u1_t,
u2_tm1,
u2_t,
......@@ -1553,33 +1559,46 @@ class T_Scan(unittest.TestCase):
y_tm1,
y_tm3,
W_in1):
return [theano.dot(u1_t, W_in1) + \
(u2_t + u2_tm1 * u2_tp1) * W_in2 + \
theano.dot(x_tm1, W),
return [theano.dot(u1_t, W_in1) +
(u2_t + u2_tm1 * u2_tp1) * W_in2 +
theano.dot(x_tm1, W),
(y_tm1 + y_tm3) * theano.dot(x_tm1, W_out),
theano.dot(u1_t, W_in1)]
cost, updates = scan_project_sum(
f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1]
params = [u1, u2, x0, y0, W_in1]
gparams = theano.tensor.grad(cost, params)
grad_fn = theano.function([u1, u2, x0, y0, W_in1],
gparams,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
cost_fn = theano.function([u1, u2, x0, y0, W_in1],
cost,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
# We change the compute_test_value[_opt] flags so that the
# asserts in Scan.grad() on the new scan input sequences related
# to outer_mitsot_outs, outer_sitsot_outs and
# outer_nitsot_outs are run. This allows testing for an old Scan bug.
old1 = theano.config.compute_test_value
old2 = theano.config.compute_test_value_opt
theano.config.compute_test_value = 'raise'
theano.config.compute_test_value_opt = 'raise'
try:
cost, updates = scan_project_sum(
f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
vparams = [v_u1, v_u2, v_x0, v_y0, vW_in1]
params = [u1, u2, x0, y0, W_in1]
gparams = theano.tensor.grad(cost, params)
grad_fn = theano.function([u1, u2, x0, y0, W_in1],
gparams,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
cost_fn = theano.function([u1, u2, x0, y0, W_in1],
cost,
updates=updates,
no_default_updates=True,
allow_input_downcast=True)
finally:
theano.config.compute_test_value = old1
theano.config.compute_test_value_opt = old2
num_grad = multiple_outputs_numeric_grad(cost_fn,
[v_u1,
......
......@@ -2543,7 +2543,7 @@ class Alloc(gof.Op):
#change.
return [gx] + [DisconnectedType()() for i in inputs[1:]]
def __call__(self, val, *shapes):
def __call__(self, val, *shapes, **kwargs):
"""
If the alloc would be useless, this function returns val.
......@@ -2554,7 +2554,7 @@ class Alloc(gof.Op):
If you always want an Alloc node, call make_node.
"""
ret = super(Alloc, self).__call__(val, *shapes)
ret = super(Alloc, self).__call__(val, *shapes, **kwargs)
try:
# It makes optimization difficult when useless allocs are thrown
# into the graph at every stage of optimization. This little logic
......
......@@ -49,14 +49,24 @@ theano.configparser.AddConfigVar('on_shape_error',
def out2in(*local_opts):
"""WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
if len(local_opts) > 1:
# Don't wrap it uselessly if there is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace)
def in2out(*local_opts, **kwargs):
"""WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
if len(local_opts) > 1:
# Don't wrap it uselessly if there is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
......@@ -384,10 +394,12 @@ def local_dimshuffle_lift(node):
input = node.inputs[0]
inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in
inode.inputs]).outputs
# Call the op instead of make_node so that tag.test_value is set.
ret = inode.op(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in
inode.inputs], **dict(return_list=True))
return ret
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order]
......@@ -397,8 +409,9 @@ def local_dimshuffle_lift(node):
iinput.type.ndim):
return [iinput]
else:
return DimShuffle(iinput.type.broadcastable, new_order,
inplace).make_node(iinput).outputs
ret = DimShuffle(iinput.type.broadcastable, new_order,
inplace)(iinput, **dict(return_list=True))
return ret
@register_canonicalize
......@@ -437,8 +450,10 @@ def dimshuffle_as_view(node):
#Step 60 is the inplace optimization stage.
compile.optdb.register('dimshuffle_as_view',
TopoOptimizer(dimshuffle_as_view,
failure_callback=TopoOptimizer.warn_inplace), 60,
TopoOptimizer(
dimshuffle_as_view,
failure_callback=TopoOptimizer.warn_inplace),
60,
'fast_run', 'inplace')
register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift)
......@@ -771,7 +786,8 @@ class ShapeFeature(object):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one
else:
return Shape_i(i).make_node(r).outputs[0]
# Call the op instead of make_node so that tag.test_value is set.
return Shape_i(i)(r)
def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r"""
......@@ -970,9 +986,9 @@ class ShapeFeature(object):
# shape var -> graph v
for node in fgraph.toposort():
self.on_import(fgraph, node)
self.on_import(fgraph, node, reason='on_attach')
def on_import(self, fgraph, node):
def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.shape_of:
# this is a revert, not really an import
for r in node.outputs + node.inputs:
......@@ -1933,7 +1949,8 @@ def local_subtensor_merge(node):
sl_ins = Subtensor.collapse(
merged_slices,
lambda x: isinstance(x, T.Variable))
out = subtens.make_node(x, *sl_ins).outputs[0]
# Call the op instead of make_node so that tag.test_value is set.
out = subtens(x, *sl_ins)
return [out]
......@@ -4583,8 +4600,12 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else:
tmp_s_input.append(scalar.Scalar(
ii.dtype).make_variable())
tmp = scalar.Scalar(ii.dtype).make_variable()
try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError:
pass
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input)
......@@ -4634,6 +4655,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s = s_inputs[inputs.index(i)]
else:
s = scalar.Scalar(i.dtype).make_variable()
try:
v = gof.op.get_test_value(i)
if v.size > 0:
s.tag.test_value = gof.op.get_test_value(i).flatten()[0]
except AttributeError:
pass
inputs.append(i)
s_inputs.append(s)
s_g.append(s)
......@@ -4667,7 +4695,8 @@ your code will run correctly, but may be slower.""")
C = scalar.Composite(s_inputs, [s_new_out])
#create the new node.
n = OP(C).make_node(*inputs)
# Call the op instead of make_node so that tag.test_value is set.
n = OP(C)(*inputs).owner
assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype
......@@ -4728,9 +4757,11 @@ if config.tensor.local_elemwise_fusion:
_logger.debug("enabling optimization fusion elemwise in fast_run")
compile.optdb.register('elemwise_fusion',
FusionOptimizer(local_elemwise_fusion), 71.00,
'fast_run', 'fusion', 'local_elemwise_fusion')
'fast_run', 'fusion', 'local_elemwise_fusion',
'FusionOptimizer')
else:
_logger.debug("not enabling optimization fusion elemwise in fast_run")
compile.optdb.register('elemwise_fusion',
FusionOptimizer(local_elemwise_fusion), 71.00,
'fusion', 'local_elemwise_fusion')
'fusion', 'local_elemwise_fusion',
'FusionOptimizer')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论