提交 0b388f2d authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Fixed error in opt local_sum_prod_mul_by_scalar. Also fixed flake8 errors.

上级 68e959df
...@@ -4602,7 +4602,6 @@ def local_fill_cut(node): ...@@ -4602,7 +4602,6 @@ def local_fill_cut(node):
# from the removed fill op, it must come from the elemntwise op. # from the removed fill op, it must come from the elemntwise op.
copy_stack_trace(node.outputs, rval) copy_stack_trace(node.outputs, rval)
if isinstance(rval, gof.Variable): if isinstance(rval, gof.Variable):
return rval.owner.outputs return rval.owner.outputs
else: else:
...@@ -5190,7 +5189,7 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -5190,7 +5189,7 @@ def local_sum_prod_mul_by_scalar(node):
ret = T.mul(*mul_inputs) ret = T.mul(*mul_inputs)
# Copy over stacktrace from previous output to new mul op, # Copy over stacktrace from previous output to new mul op,
# for same reason as above. # for same reason as above.
copy_stack_trace(node.outputs, ret+mul_inputs) copy_stack_trace(node.outputs,[ret]+mul_inputs)
return [ret] return [ret]
......
...@@ -3677,7 +3677,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3677,7 +3677,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0, op=T.alloc) self.assert_eqs_const(f, 0, op=T.alloc)
assert (f([3, 3]) == 0).all() assert (f([3, 3]) == 0).all()
def test_and(self): def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
...@@ -3724,7 +3723,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3724,7 +3723,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.xor(x, x), mode=mode) f = theano.function([x], T.xor(x, x), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
def test_stacktrace(self): def test_stacktrace(self):
mode = theano.compile.get_default_mode().including( mode = theano.compile.get_default_mode().including(
'local_useless_elemwise_comparison') 'local_useless_elemwise_comparison')
...@@ -6310,7 +6308,7 @@ class Test_local_useless_reshape(unittest.TestCase): ...@@ -6310,7 +6308,7 @@ class Test_local_useless_reshape(unittest.TestCase):
# TODO: Check that stack trace is maintained. # TODO: Check that stack trace is maintained.
# Currently, stack trace gets removed by some other opt. # Currently, stack trace gets removed by some other opt.
#assert check_stack_trace(f1, ops_to_check='all') # assert check_stack_trace(f1, ops_to_check='all')
m2 = m0.excluding('local_useless_reshape') m2 = m0.excluding('local_useless_reshape')
...@@ -6331,7 +6329,7 @@ class Test_local_useless_reshape(unittest.TestCase): ...@@ -6331,7 +6329,7 @@ class Test_local_useless_reshape(unittest.TestCase):
# TODO: Check that stack trace is maintained. # TODO: Check that stack trace is maintained.
# Currently, stack trace gets removed by some other opt. # Currently, stack trace gets removed by some other opt.
#assert check_stack_trace(f1, ops_to_check='all') # assert check_stack_trace(f1, ops_to_check='all')
m2 = m1.excluding('ShapeOpt') m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2) f2 = theano.function([x], r, mode=m2)
...@@ -6385,6 +6383,7 @@ def test_local_reshape_lift(): ...@@ -6385,6 +6383,7 @@ def test_local_reshape_lift():
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check='last') assert check_stack_trace(f, ops_to_check='last')
class Test_lift_transpose_through_dot(unittest.TestCase): class Test_lift_transpose_through_dot(unittest.TestCase):
def simple_optimize(self, g): def simple_optimize(self, g):
out2in(opt.local_useless_elemwise).optimize(g) out2in(opt.local_useless_elemwise).optimize(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论