提交 bde718ca authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Tried implementing tests for local_useless_reshape, but could not.

上级 ffd31201
......@@ -4226,7 +4226,7 @@ def local_flatten_lift(node):
# Copy over stacktrace from previous output node and from unary
# elementwise output node since if there was an error, it would
# probably have come from that operation.
copy_stack_trace(node.outputs + node.inputs, e)
copy_stack_trace(node.outputs + node.inputs[0], e)
return [e]
......@@ -4289,6 +4289,10 @@ def local_useless_reshape(node):
return False
input = node.inputs[0]
# Copy over stack trace
copy_stack_trace(node.outputs[0], input)
output = node.outputs[0]
output_shape = node.inputs[1]
......@@ -4598,6 +4602,7 @@ def local_fill_cut(node):
# from the removed fill op, it must come from the elemntwise op.
copy_stack_trace(node.outputs, rval)
if isinstance(rval, gof.Variable):
return rval.owner.outputs
else:
......@@ -5157,10 +5162,9 @@ def local_sum_prod_mul_by_scalar(node):
new_op_input_nb_elements = new_op_input.size
new_op_output = node.op(new_op_input)
if not len(non_scalars) == 0:
# Copy over stacktrace from previous output to new mul op,
# for same reason as above.
copy_stack_trace(node.outputs, new_op_output)
# Copy over stacktrace from previous output to new mul op,
# for same reason as above.
copy_stack_trace(node.outputs, new_op_output)
# If node.op is a T.elemwise.Prod, then the scalars need to be
# raised to the power of the number of elements in the input
......@@ -5186,7 +5190,7 @@ def local_sum_prod_mul_by_scalar(node):
ret = T.mul(*mul_inputs)
# Copy over stacktrace from previous output to new mul op,
# for same reason as above.
copy_stack_trace(node.outputs, [ret] + mul_inputs)
copy_stack_trace(node.outputs, ret+mul_inputs)
return [ret]
......@@ -5196,7 +5200,7 @@ def local_sum_prod_mul_by_scalar(node):
# There are never errors in the negative op, thus
# we need only to copy over stacktrace from previous output node to
# the two new ops.
copy_stack_trace(node.outputs, [s, ret])
copy_stack_trace(node.outputs, s+ret)
return [ret]
......@@ -5212,8 +5216,8 @@ def local_elemwise_sub_zeros(node):
node.op.scalar_op == scalar.sub and
node.inputs[0] == node.inputs[1]):
res = T.zeros_like(node.inputs[0])
# Copy over stacktrace from previous output.
# Julian: Pascal, is this really necessary? Is there anyway zeros_like can ever fail?
copy_stack_trace(node.outputs, res)
return [res]
......
......@@ -6294,17 +6294,17 @@ class Test_local_useless_reshape(unittest.TestCase):
topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check='all')
# TODO: Check that stack trace is maintained.
# Currently, stack trace gets removed by some other opt.
#assert check_stack_trace(f1, ops_to_check='all')
m2 = m0.excluding('local_useless_reshape')
m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check='all')
def test_2(self):
x = theano.tensor.matrix('x')
r = x.reshape([Shape_i(i)(x) for i in xrange(x.ndim)])
......@@ -6315,17 +6315,15 @@ class Test_local_useless_reshape(unittest.TestCase):
topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check='all')
# TODO: Check that stack trace is maintained.
# Currently, stack trace gets removed by some other opt.
#assert check_stack_trace(f1, ops_to_check='all')
m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check='all')
class Test_local_reshape_to_dimshuffle(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论