提交 95fba1b6 authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Merged previous work implementing stack trace copy over and tests for various ops. #3018

上级 4afefe16
...@@ -4216,7 +4216,18 @@ def local_flatten_lift(node): ...@@ -4216,7 +4216,18 @@ def local_flatten_lift(node):
isinstance(node.inputs[0].owner.op, T.Elemwise) and isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1): len(node.inputs[0].owner.inputs) == 1):
f = node.op(node.inputs[0].owner.inputs[0]) f = node.op(node.inputs[0].owner.inputs[0])
# Copy over stacktrace from previous output node (flatten op),
# since this is the op which may cause an error for f.
copy_stack_trace(node.outputs, f)
e = node.inputs[0].owner.op(f) e = node.inputs[0].owner.op(f)
# Copy over stacktrace from previous output node and from unary
# elementwise output node since if there was an error, it would
# probably have come from that operation.
copy_stack_trace(node.outputs + node.inputs[0], e)
return [e] return [e]
################## ##################
...@@ -4237,6 +4248,12 @@ def local_reshape_chain(op): ...@@ -4237,6 +4248,12 @@ def local_reshape_chain(op):
# TODO: this can permit a failing program to run by eliminating # TODO: this can permit a failing program to run by eliminating
# the lower reshape # the lower reshape
rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
# Copy over stacktrace from previous output node, as any error
# in new computational graph would have been caused by last op
# in the old computational graph.
copy_stack_trace(node.outputs, rval)
# It might happen that the desired output of this node has a # It might happen that the desired output of this node has a
# broadcastable pattern that does not match that of 'rval'. This is # broadcastable pattern that does not match that of 'rval'. This is
# when originally, we were able to figure out that one of the # when originally, we were able to figure out that one of the
...@@ -4275,6 +4292,62 @@ def local_useless_reshape(node): ...@@ -4275,6 +4292,62 @@ def local_useless_reshape(node):
output = node.outputs[0] output = node.outputs[0]
output_shape = node.inputs[1] output_shape = node.inputs[1]
if input.ndim != output.ndim:
return False
# Simple case: both input and output have a single dimension.
# This could hide errors if the user provides inconsistent shapes.
if (input.ndim == 1 and output.ndim == 1 and
input.broadcastable == output.broadcastable):
return [input]
# Second case: all the shapes match the input shape
# Match Reshape(x, x.shape)
if output_shape.owner and isinstance(output_shape.owner.op, Shape):
shape_input = output_shape.owner.inputs[0]
if shape_input == input:
return [input]
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
# broadcastable and constant dimensions
if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
output_shape_is = output_shape.owner.inputs
if not hasattr(node, 'fgraph'):
shape_feature = None
else:
shape_feature = getattr(node.fgraph, 'shape_feature', None)
shape_match = [False] * input.ndim
for dim in xrange(input.ndim):
outshp_i = output_shape_is[dim]
# Match Shape_i{dim}(input)
if (outshp_i.owner and isinstance(outshp_i.owner.op, Shape_i) and
outshp_i.owner.op.i == dim and
outshp_i.owner.inputs[0] == input):
shape_match[dim] = True
continue
# Match Shape(input)[dim]
if (outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and
len(outshp_i.owner.inputs) == 2 and
extract_constant(outshp_i.owner.inputs[1]) == dim):
subtensor_inp = outshp_i.owner.inputs[0]
if (subtensor_inp.owner and
isinstance(subtensor_inp.owner.op, Shape)):
shape_input_i = subtensor_inp.owner.inputs[0]
if shape_input_i == input:
shape_match[dim] = True
continue
op = node.op
if not isinstance(op, Reshape):
return False
input = node.inputs[0]
output = node.outputs[0]
output_shape = node.inputs[1]
if input.ndim != output.ndim: if input.ndim != output.ndim:
return False return False
...@@ -4359,7 +4432,6 @@ def local_reshape_to_dimshuffle(node): ...@@ -4359,7 +4432,6 @@ def local_reshape_to_dimshuffle(node):
- reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,)) - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
- reshape(x, (1, m, 1, n, 1, 1)) - reshape(x, (1, m, 1, n, 1, 1))
--> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n))) --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
""" """
op = node.op op = node.op
if not isinstance(op, Reshape): if not isinstance(op, Reshape):
...@@ -4408,16 +4480,33 @@ def local_reshape_lift(node): ...@@ -4408,16 +4480,33 @@ def local_reshape_lift(node):
isinstance(node.inputs[0].owner.op, T.Elemwise) and isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1): len(node.inputs[0].owner.inputs) == 1):
r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
# Copy stacktrace from previous Reshape op, as an error in new
# Reshape op could only have been caused by old one.
copy_stack_trace(node.outputs, r)
e = node.inputs[0].owner.op(r) e = node.inputs[0].owner.op(r)
# Copy stacktrace from both previous Reshape and UnaryElemwise op
            # because an error in the new computational graph could have been caused by either op.
copy_stack_trace(node.outputs + node.inputs, e)
# In rare case the original broadcast was (False, True), but # In rare case the original broadcast was (False, True), but
# the new one is (False, False). So don't crash in that case. # the new one is (False, False). So don't crash in that case.
if e.type != node.outputs[0].type: if e.type != node.outputs[0].type:
e = T.patternbroadcast(e, node.outputs[0].broadcastable) re = T.patternbroadcast(e, node.outputs[0].broadcastable)
return [e]
# We assume that the broadcast op cannot fail. Thus, if the
# graph fails it must be due to previous UnaryElemwise op, and
# therefore we must copy its stacktrace over.
copy_stack_trace(e, re)
else:
re = e
return [re]
if 0: if 0:
    # TODO: Test that this optimization works. # TODO: Test that this optimization works.
# TODO: Once it works, copy over stacktrace appropriately.
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Reshape]) @gof.local_optimizer([T.Reshape])
def local_scalar_reshape(node): def local_scalar_reshape(node):
...@@ -4434,6 +4523,7 @@ if 0: ...@@ -4434,6 +4523,7 @@ if 0:
# appropriately typed and broadcasted zero. # appropriately typed and broadcasted zero.
# TODO: Remember to take into account the new sum dtype argument if this # TODO: Remember to take into account the new sum dtype argument if this
# optimization is enabled. # optimization is enabled.
# TODO: Once it works, copy over stacktrace appropriately.
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum])
def local_sum_over_empty(node): def local_sum_over_empty(node):
...@@ -4465,11 +4555,11 @@ def local_fill_cut(node): ...@@ -4465,11 +4555,11 @@ def local_fill_cut(node):
If c.type == a.type. If c.type == a.type.
""" """
# this optimization is essentially for getting broadcasting to # this optimization is basically for getting broadcasting to
# replace fill. This is always possible when using a Compound # replace fill. This is always possible when using a Compound
# Elemwise operation, but it is not always possible without one # Elemwise operation, but it is not always possible without one
# (consider filling a large matrix with a scalar, and then adding # (consider filling a large matrix with a scalar, and then adding
# another scalar. The only numbers that count are the two # another scalar). The only numbers that count are the two
# scalars, but we can't ignore the large matrix because it gives # scalars, but we can't ignore the large matrix because it gives
# the shape of the result. # the shape of the result.
...@@ -4503,6 +4593,12 @@ def local_fill_cut(node): ...@@ -4503,6 +4593,12 @@ def local_fill_cut(node):
return False return False
rval = node.op(*new_inputs) rval = node.op(*new_inputs)
# Copy over stacktrace from previous elementwise op output.
    # Since we are certain that an error in the computational graph can
    # never come from the removed fill op, it must come from the elementwise op.
copy_stack_trace(node.outputs, rval)
if isinstance(rval, gof.Variable): if isinstance(rval, gof.Variable):
return rval.owner.outputs return rval.owner.outputs
else: else:
...@@ -4966,6 +5062,10 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4966,6 +5062,10 @@ class Canonizer(gof.LocalOptimizer):
# This happen with test # This happen with test
# theano/tensor/tests/test_opt.py:T_local_switch_sink # theano/tensor/tests/test_opt.py:T_local_switch_sink
new.tag.values_eq_approx = values_eq_approx_remove_inf_nan new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
# Julian: Pascal, maybe you can help me implement the copying of the stacktrace for this class?
            # Because it's so general, I think we need to copy over the stacktraces of all ops being replaced
# to every new op?
return [new] return [new]
else: else:
_logger.warning(' '.join(('CANONIZE FAILED: new, out = ', _logger.warning(' '.join(('CANONIZE FAILED: new, out = ',
...@@ -5050,9 +5150,18 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -5050,9 +5150,18 @@ def local_sum_prod_mul_by_scalar(node):
new_op_output = node.op(non_scalars[0]) new_op_output = node.op(non_scalars[0])
else: else:
new_op_input = T.mul(*non_scalars) new_op_input = T.mul(*non_scalars)
# We assume that errors always come from the prod/mul op in the
# original computational graph, and therefore need to only
# copy over its output stacktrace.
copy_stack_trace(node.outputs, new_op_input)
new_op_input_nb_elements = new_op_input.size new_op_input_nb_elements = new_op_input.size
new_op_output = node.op(new_op_input) new_op_output = node.op(new_op_input)
# Copy over stacktrace from previous output to new mul op,
# for same reason as above.
copy_stack_trace(node.outputs, new_op_output)
# If node.op is a T.elemwise.Prod, then the scalars need to be # If node.op is a T.elemwise.Prod, then the scalars need to be
# raised to the power of the number of elements in the input # raised to the power of the number of elements in the input
# to the Prod # to the Prod
...@@ -5068,12 +5177,28 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -5068,12 +5177,28 @@ def local_sum_prod_mul_by_scalar(node):
mul_inputs.append(new_op_output) mul_inputs.append(new_op_output)
if len(mul_inputs) == 1: if len(mul_inputs) == 1:
# Copy over stacktrace from previous output to new mul op,
# for same reason as above.
copy_stack_trace(node.outputs, mul_inputs)
return mul_inputs return mul_inputs
else: else:
return [T.mul(*mul_inputs)] ret = T.mul(*mul_inputs)
# Copy over stacktrace from previous output to new mul op,
# for same reason as above.
copy_stack_trace(node.outputs, ret+mul_inputs)
return [ret]
if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg: if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
return [T.neg(node.op(node_inps.owner.inputs[0]))] s = node.op(node_inps.owner.inputs[0])
ret = T.neg(s)
# There are never errors in the negative op, thus
# we need only to copy over stacktrace from previous output node to
# the two new ops.
copy_stack_trace(node.outputs, s+ret)
return [ret]
@register_specialize @register_specialize
...@@ -5086,7 +5211,11 @@ def local_elemwise_sub_zeros(node): ...@@ -5086,7 +5211,11 @@ def local_elemwise_sub_zeros(node):
node.op.scalar_op.nin == 2 and node.op.scalar_op.nin == 2 and
node.op.scalar_op == scalar.sub and node.op.scalar_op == scalar.sub and
node.inputs[0] == node.inputs[1]): node.inputs[0] == node.inputs[1]):
return [T.zeros_like(node.inputs[0])] res = T.zeros_like(node.inputs[0])
# Copy over stacktrace from previous output.
        # Julian: Pascal, is this really necessary? Is there any way zeros_like can ever fail?
copy_stack_trace(node.outputs, res)
return [res]
@register_useless @register_useless
...@@ -5133,54 +5262,77 @@ def local_useless_elemwise_comparison(node): ...@@ -5133,54 +5262,77 @@ def local_useless_elemwise_comparison(node):
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \ if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)] res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \ if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)] res = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[{minimum,maximum}](X, X) -> X # Elemwise[{minimum,maximum}](X, X) -> X
if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \ if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [node.inputs[0]] res = node.inputs[0]
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) # Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, scalar.LT) and \ if isinstance(node.op.scalar_op, scalar.LT) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)] res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)] res = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
# No need to copy over stacktrace.
return [node.inputs[0]] return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i] # Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
# No need to copy over stacktrace.
return [node.inputs[1]] return [node.inputs[1]]
# Elemwise[minimum](X.shape[i], 0) -> 0 # Elemwise[minimum](X.shape[i], 0) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)] res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
    # It doesn't detect the case where the 0 is all zeros with ndim > 0.
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.inputs[1], dtype=dtype, opt=True)] res = T.zeros_like(node.inputs[1], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, scalar.LT) and \ if isinstance(node.op.scalar_op, scalar.LT) and \
...@@ -5190,8 +5342,10 @@ def local_useless_elemwise_comparison(node): ...@@ -5190,8 +5342,10 @@ def local_useless_elemwise_comparison(node):
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)] # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
...@@ -5200,7 +5354,11 @@ def local_useless_elemwise_comparison(node): ...@@ -5200,7 +5354,11 @@ def local_useless_elemwise_comparison(node):
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)] res = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[EQ](Subtensor(Shape(x)), -N) # Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](somegraph that only depend of shape, -N) # Elemwise[EQ](somegraph that only depend of shape, -N)
......
...@@ -3566,6 +3566,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3566,6 +3566,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
assert isinstance(elem.inputs[0], T.TensorConstant), elem assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val assert T.extract_constant(elem.inputs[0]) == val, val
def assert_identity(self, f): def assert_identity(self, f):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
...@@ -3661,6 +3662,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3661,6 +3662,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.eq(g, 0)) f = theano.function([x], T.eq(g, 0))
assert f([3, 3]) == 0 assert f([3, 3]) == 0
assert f([]) == 1 assert f([]) == 1
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.eq(g, -1)) f = theano.function([x], T.eq(g, -1))
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
...@@ -3672,6 +3674,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3672,6 +3674,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.eq(g, 0)) f = theano.function([x], T.eq(g, 0))
assert (f([3, 3]) == 0).all() assert (f([3, 3]) == 0).all()
assert (f([]) == 1).all() assert (f([]) == 1).all()
self.assertTrue(check_stack_trace(f, ops_to_check='last'))
f = theano.function([x], T.eq(g, -1)) f = theano.function([x], T.eq(g, -1))
self.assert_eqs_const(f, 0, op=T.alloc) self.assert_eqs_const(f, 0, op=T.alloc)
...@@ -6291,11 +6294,17 @@ class Test_local_useless_reshape(unittest.TestCase): ...@@ -6291,11 +6294,17 @@ class Test_local_useless_reshape(unittest.TestCase):
topo = f1.maker.fgraph.toposort() topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo) assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check='all')
m2 = m1.excluding('ShapeOpt') m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2) f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort() topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo) assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check='all')
def test_2(self): def test_2(self):
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
r = x.reshape([Shape_i(i)(x) for i in xrange(x.ndim)]) r = x.reshape([Shape_i(i)(x) for i in xrange(x.ndim)])
...@@ -6306,11 +6315,17 @@ class Test_local_useless_reshape(unittest.TestCase): ...@@ -6306,11 +6315,17 @@ class Test_local_useless_reshape(unittest.TestCase):
topo = f1.maker.fgraph.toposort() topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo) assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check='all')
m2 = m1.excluding('ShapeOpt') m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2) f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort() topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo) assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check='all')
class Test_local_reshape_to_dimshuffle(unittest.TestCase): class Test_local_reshape_to_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -6341,7 +6356,7 @@ class Test_local_reshape_to_dimshuffle(unittest.TestCase): ...@@ -6341,7 +6356,7 @@ class Test_local_reshape_to_dimshuffle(unittest.TestCase):
"TensorConstant{[5 6]}))]") "TensorConstant{[5 6]}))]")
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
check_stack_trace(g, ops_to_check=(T.DimShuffle, T.Reshape)) assert check_stack_trace(g, ops_to_check=(T.DimShuffle, T.Reshape))
def test_local_reshape_lift(): def test_local_reshape_lift():
...@@ -6355,7 +6370,8 @@ def test_local_reshape_lift(): ...@@ -6355,7 +6370,8 @@ def test_local_reshape_lift():
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert isinstance(topo[-2].op, tensor.Reshape) assert isinstance(topo[-2].op, tensor.Reshape)
assert isinstance(topo[-1].op, tensor.Elemwise) assert isinstance(topo[-1].op, tensor.Elemwise)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check='last')
class Test_lift_transpose_through_dot(unittest.TestCase): class Test_lift_transpose_through_dot(unittest.TestCase):
def simple_optimize(self, g): def simple_optimize(self, g):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论