Commit 0294429e, authored by Iulian Vlad Serban

More work on issue #3018.

Parent commit: 8571cb47
......@@ -3402,8 +3402,21 @@ def local_rebroadcast_lift(node):
# by the `unbroadcast` function before we are in the actual function
# compilation phase.
if hasattr(input, 'clients') and len(input.clients) == 1:
rval = inode.op.make_node(T.Rebroadcast(*list(op.axis.items()))(
inode.inputs[0])).outputs
rebroadcasted = T.Rebroadcast(*list(op.axis.items()))(
inode.inputs[0])
# Copy over stacktrace from previous output (after rebroadcasting)
# to new output, because an error in the new graph right after
# rebroadcasting must have been caused by the previous rebroadcasting.
copy_stack_trace(node.outputs, rebroadcasted)
rval = inode.op.make_node(rebroadcasted).outputs
# Copy over stacktrace from previous output (after rebroadcasting)
# and input (after elemwise operation) to new output, because an
# error in the new graph could have been caused by either of the
# two ops.
copy_stack_trace(node.outputs+node.inputs, rval)
return rval
if inode and isinstance(inode.op, T.Rebroadcast):
# the "axis" specification in the outer Rebroadcast overrides
......@@ -3411,7 +3424,14 @@ def local_rebroadcast_lift(node):
axis = inode.op.axis.copy()
axis.update(op.axis)
iinput = inode.inputs[0]
rval = [T.Rebroadcast(*list(axis.items()))(iinput)]
# Copy over stacktrace from previous output (after second rebroadcast)
# and from previous input (after first rebroadcast op) because an error in
# the new graph could have been caused by either of the two
# rebroadcast ops.
copy_stack_trace(node.outputs+node.inputs, rval)
return rval
......@@ -3465,6 +3485,8 @@ def local_join_1(node):
return
tensors = node.inputs[1:]
if len(tensors) == 1:
# We don't need to copy over any stacktrace here, because the
# input variable should already have its own stacktrace.
return [tensors[0]]
......@@ -3507,6 +3529,12 @@ def local_join_empty(node):
assert ret.dtype == o.dtype
assert ret.ndim == o.ndim
ret = T.patternbroadcast(ret, node.outputs[0].broadcastable)
# Copy over stacktrace from previous output (after join op)
# to new output, because an error in the new op must be caused
# by an error in the old join op.
copy_stack_trace(node.outputs, ret)
return [ret]
......@@ -3533,10 +3561,20 @@ def local_join_make_vector(node):
inp.owner.op == new_inputs[-1].owner.op):
inps = new_inputs[-1].owner.inputs + inp.owner.inputs
new_inputs[-1] = inp.owner.op(*inps)
# Copy over stacktrace from previous output (after join op)
# to new intermediate output, because an error in the intermediate
# op must be caused by an error in the old join op.
copy_stack_trace(node.outputs, new_inputs[-1])
else:
new_inputs.append(inp)
if len(new_inputs) < len(node.inputs) - 1:
ret = T.join(node.inputs[0], *new_inputs)
# Copy over stacktrace from previous output (after join op)
# to new output, because an error in the new op must be caused
# by an error in the old join op.
copy_stack_trace(node.outputs, ret)
return [ret]
......@@ -3562,25 +3600,33 @@ def local_useless_switch(node):
cond = T.extract_constant(node.inputs[0], elemwise=False)
if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0:
out = node.inputs[2]
correct_out = node.inputs[2]
else:
out = node.inputs[1]
correct_out = node.inputs[1]
if out.ndim != node.outputs[0].ndim:
if correct_out.ndim != node.outputs[0].ndim:
# TODO: broadcast?
return False
if out.dtype != node.outputs[0].dtype:
out = T.cast(out, node.outputs[0].dtype)
if out.type.broadcastable != node.outputs[0].type.broadcastable:
if correct_out.dtype != node.outputs[0].dtype:
out = T.cast(correct_out, node.outputs[0].dtype)
if correct_out.type.broadcastable != node.outputs[0].type.broadcastable:
# We need to copy data to the new dimensions during execution
out = T.alloc(out, *[node.outputs[0].shape[i] for i
in xrange(out.ndim)])
out = T.alloc(correct_out, *[node.outputs[0].shape[i] for i
in xrange(correct_out.ndim)])
else:
out = correct_out
# Copy over stacktrace from selected output to new output
copy_stack_trace(node.outputs+correct_out, out)
return [out]
# if left is right -> left
if node.inputs[1] is node.inputs[2]:
if cond.type == node.inputs[1].type:
return [node.inputs[1]]
return [T.fill(cond, node.inputs[1])]
ret = T.fill(cond, node.inputs[1])
# Copy over stacktrace from switch output and correct branch
copy_stack_trace(node.outputs+node.inputs[1], ret)
return [ret]
# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
......@@ -3596,6 +3642,8 @@ def local_useless_switch(node):
T.extract_constant(left) == 0 and \
right is cond_var.owner.inputs[0]:
assert right.type == node.outputs[0].type
# No need to copy over stacktrace, because the right input node
# already has its own stacktrace
return [right]
return False
return False
......@@ -3636,9 +3684,24 @@ def local_mul_switch_sink(node):
if (get_scalar_constant_value(
switch.inputs[1], only_process_constants=True) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fmul = T.mul(*(listmul + [switch.inputs[2]]))
# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
# multiplication op.
copy_stack_trace(node.outputs, fmul)
fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))]
fmul)]
fct[0].values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs+switch.outputs, fct)
return fct
except NotScalarConstantError:
pass
......@@ -3646,9 +3709,23 @@ def local_mul_switch_sink(node):
if (get_scalar_constant_value(
switch.inputs[2], only_process_constants=True) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fmul = T.mul(*(listmul + [switch.inputs[1]]))
# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
# multiplication op.
copy_stack_trace(node.outputs, fmul)
fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)]
fmul, 0)]
fct[0].values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs+switch.outputs, fct)
return fct
except NotScalarConstantError:
pass
......@@ -3676,17 +3753,45 @@ def local_div_switch_sink(node):
switch = node.inputs[0].owner
try:
if get_scalar_constant_value(switch.inputs[1]) == 0.:
fdiv = op(switch.inputs[2], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)
fct = [T.switch(switch.inputs[0], 0,
op(switch.inputs[2], node.inputs[1]))]
fdiv)]
fct[0].values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs+switch.outputs, fct)
return fct
except NotScalarConstantError:
pass
try:
if get_scalar_constant_value(switch.inputs[2]) == 0.:
fdiv = op(switch.inputs[1], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)
fct = [T.switch(switch.inputs[0],
op(switch.inputs[1], node.inputs[1]), 0)]
fdiv, 0)]
fct[0].values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs+switch.outputs, fct)
return fct
except NotScalarConstantError:
pass
......@@ -3713,6 +3818,8 @@ def local_useless_tile(node):
try:
l = T.get_vector_length(node.inputs[1])
if l == node.inputs[0].ndim:
# No need to copy over any stacktrace as previous
# input variable already has a stacktrace
return [node.inputs[0]]
elif l < node.inputs[0].ndim:
# The Op don't support that case, so we can't
......@@ -3725,7 +3832,11 @@ def local_useless_tile(node):
return
x_nd = node.inputs[0].ndim
broad = ['x'] * (l - x_nd) + xrange(x_nd)
return [node.inputs[0].dimshuffle(broad)]
ret = node.inputs[0].dimshuffle(broad)
# Copy over stacktrace from previous output node,
# and from node before tiling operation.
copy_stack_trace(node.outputs+node.inputs[0], ret)
return [ret]
except ValueError:
return
except NotScalarConstantError:
......@@ -3749,6 +3860,9 @@ def local_useless_split(node):
x, axis, splits = node.inputs
out = assert_op(x, T.eq(splits.shape[0], 1))
out = assert_op(out, T.eq(x.shape[axis], splits[0]))
# Copy over stacktrace from previous output node.
copy_stack_trace(node.outputs, out)
return [out]
......
......@@ -4026,6 +4026,10 @@ class T_Tile(unittest.TestCase):
assert len(topo) == 1
assert isinstance(topo[0].op, compile.DeepCopyOp)
f(data)
# Check that stacktrace is copied over
self.assertTrue(hasattr(f.outputs[0].variable.tag, 'trace'))
self.assertTrue(len(f.outputs[0].variable.tag.trace)>0)
def speed_local_pow_specialize_range():
......@@ -5711,6 +5715,7 @@ def test_local_join_empty():
for n in e if isinstance(n.op, Join)])
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for matrix join(1,a)
empty_mat = numpy.asarray([[]], dtype=config.floatX)
m = tensor.matrix('m')
......@@ -5723,7 +5728,6 @@ def test_local_join_empty():
assert all([not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e if isinstance(n.op, Join)])
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for vector, vector, empty to matrix
# We can't optimize this case.
s = tensor.stack([a, a, empty_vec])
......@@ -5735,7 +5739,6 @@ def test_local_join_empty():
assert all([not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e if isinstance(n.op, Join)])
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for matrix join(0,a)
# We can't optimize this case.
s = join(0, m, numpy.asarray([[2.]], dtype=config.floatX), m)
......@@ -5747,6 +5750,20 @@ def test_local_join_empty():
assert all([not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e if isinstance(n.op, Join)])
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# Julian: we can enable the following test, once we
# remove default optimizations.
# When we set optimizer=None, no optimizations should be applied,
# but that's not the case now...
# test that optimizations keep stack trace
#mode = theano.compile.mode.Mode(optimizer=None).including('canonicalize_db').including("local_join_empty")
#empty_mat = numpy.asarray([[]], dtype=config.floatX)
#m = tensor.matrix('m')
#s = join(1, empty_mat, m, m, m)
#f = function([m], s, mode=mode)
#assert hasattr(f.outputs[0].variable.tag, 'trace')
#assert len(f.outputs[0].variable.tag.trace) > 0
def test_local_join_make_vector():
......@@ -5765,6 +5782,10 @@ def test_local_join_make_vector():
assert f.maker.fgraph.outputs[0].dtype == config.floatX
print(f.outputs[0].variable.tag)
print(f.outputs[0].variable.tag.trace)
def test_local_add_specialize():
# test of non-zero dimension
a = tensor.vector()
......@@ -5864,6 +5885,12 @@ def test_local_useless_split():
assert len(graph_nonopt)==1
assert isinstance(graph_nonopt[0].op, tensor.Split)
# Check that stacktraces have been copied over properly
assert hasattr(f_opt.outputs[0].variable.tag, 'trace')
assert len(f_opt.outputs[0].variable.tag.trace) > 0
assert hasattr(f_nonopt.outputs[0].variable.tag, 'trace')
assert len(f_nonopt.outputs[0].variable.tag.trace) > 0
def test_local_flatten_lift():
for i in xrange(1, 4):
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment