提交 4a1fc766 authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Implemented Pascals comments. Improved test for local_elemwise_sub_zeros.

上级 ac601d00
...@@ -4299,9 +4299,6 @@ def local_useless_reshape(node): ...@@ -4299,9 +4299,6 @@ def local_useless_reshape(node):
# This could hide errors if the user provides inconsistent shapes. # This could hide errors if the user provides inconsistent shapes.
if (input.ndim == 1 and output.ndim == 1 and if (input.ndim == 1 and output.ndim == 1 and
input.broadcastable == output.broadcastable): input.broadcastable == output.broadcastable):
# Copy over stack trace
copy_stack_trace(node.outputs[0], input)
return [input] return [input]
# Second case: all the shapes match the input shape # Second case: all the shapes match the input shape
...@@ -4309,9 +4306,6 @@ def local_useless_reshape(node): ...@@ -4309,9 +4306,6 @@ def local_useless_reshape(node):
if output_shape.owner and isinstance(output_shape.owner.op, Shape): if output_shape.owner and isinstance(output_shape.owner.op, Shape):
shape_input = output_shape.owner.inputs[0] shape_input = output_shape.owner.inputs[0]
if shape_input == input: if shape_input == input:
# Copy over stack trace
copy_stack_trace(node.outputs[0], input)
return [input] return [input]
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
...@@ -4362,9 +4356,6 @@ def local_useless_reshape(node): ...@@ -4362,9 +4356,6 @@ def local_useless_reshape(node):
continue continue
if all(shape_match): if all(shape_match):
# Copy over stack trace
copy_stack_trace(node.outputs[0], input)
return [input] return [input]
# TODO later: if all the shapes except one match, we may want to # TODO later: if all the shapes except one match, we may want to
......
...@@ -3502,6 +3502,8 @@ def test_local_elemwise_sub_zeros(): ...@@ -3502,6 +3502,8 @@ def test_local_elemwise_sub_zeros():
f = function([scalar], scalar - scalar, mode=mode) f = function([scalar], scalar - scalar, mode=mode)
# Check optimized graph is correct # Check optimized graph is correct
assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise) assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise)
assert f.maker.fgraph.toposort()[0].op.name\
== 'Elemwise{second,no_inplace}'
assert isinstance(f.maker.fgraph.toposort()[0].inputs[1], assert isinstance(f.maker.fgraph.toposort()[0].inputs[1],
T.TensorConstant) or\ T.TensorConstant) or\
isinstance(f.maker.fgraph.toposort()[0].inputs[1], isinstance(f.maker.fgraph.toposort()[0].inputs[1],
...@@ -3514,6 +3516,8 @@ def test_local_elemwise_sub_zeros(): ...@@ -3514,6 +3516,8 @@ def test_local_elemwise_sub_zeros():
f = function([vect], vect - vect, mode=mode) f = function([vect], vect - vect, mode=mode)
# Check optimized graph is correct # Check optimized graph is correct
assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise) assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise)
assert f.maker.fgraph.toposort()[0].op.name\
== 'Elemwise{second,no_inplace}'
assert isinstance(f.maker.fgraph.toposort()[0].inputs[1], assert isinstance(f.maker.fgraph.toposort()[0].inputs[1],
T.TensorConstant) or\ T.TensorConstant) or\
isinstance(f.maker.fgraph.toposort()[0].inputs[1], isinstance(f.maker.fgraph.toposort()[0].inputs[1],
...@@ -3526,6 +3530,8 @@ def test_local_elemwise_sub_zeros(): ...@@ -3526,6 +3530,8 @@ def test_local_elemwise_sub_zeros():
f = function([mat], mat - mat, mode=mode) f = function([mat], mat - mat, mode=mode)
# Check optimized graph is correct # Check optimized graph is correct
assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise) assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise)
assert f.maker.fgraph.toposort()[0].op.name\
== 'Elemwise{second,no_inplace}'
assert isinstance(f.maker.fgraph.toposort()[0].inputs[1], assert isinstance(f.maker.fgraph.toposort()[0].inputs[1],
T.TensorConstant) or\ T.TensorConstant) or\
isinstance(f.maker.fgraph.toposort()[0].inputs[1], isinstance(f.maker.fgraph.toposort()[0].inputs[1],
...@@ -5416,8 +5422,6 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -5416,8 +5422,6 @@ class T_local_sum_prod(unittest.TestCase):
(s1_val * s2_val * v_val * m_val).prod(), 2) (s1_val * s2_val * v_val * m_val).prod(), 2)
def test_local_sum_prod_all_to_none(self): def test_local_sum_prod_all_to_none(self):
# Julian: It appears that the opt local_sum_prod_mul_by_scalar
# is never used in any of these tests...
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
# test sum # test sum
...@@ -5652,10 +5656,8 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -5652,10 +5656,8 @@ class T_local_sum_prod(unittest.TestCase):
finally: finally:
config.on_opt_error = backup config.on_opt_error = backup
def test_stack_trace(self): def test_local_sum_prod_mul_by_scalar_stack_trace(self):
""" # Test that stack trace is copied over correctly for local_sum_prod_mul_by_scalar.
Test that stack trace is copied over correctly.
"""
m0 = theano.compile.get_default_mode()\ m0 = theano.compile.get_default_mode()\
.excluding('inplace_elemwise_opt')\ .excluding('inplace_elemwise_opt')\
.including('canonicalize', 'specialize') .including('canonicalize', 'specialize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论