提交 0b388f2d authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Fixed error in opt local_sum_prod_mul_by_scalar. Also fixed flake8 errors.

上级 68e959df
......@@ -4602,7 +4602,6 @@ def local_fill_cut(node):
# from the removed fill op, it must come from the elemntwise op.
copy_stack_trace(node.outputs, rval)
if isinstance(rval, gof.Variable):
return rval.owner.outputs
else:
......@@ -5190,7 +5189,7 @@ def local_sum_prod_mul_by_scalar(node):
ret = T.mul(*mul_inputs)
# Copy over stacktrace from previous output to new mul op,
# for same reason as above.
copy_stack_trace(node.outputs, ret+mul_inputs)
copy_stack_trace(node.outputs,[ret]+mul_inputs)
return [ret]
......
......@@ -3677,7 +3677,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0, op=T.alloc)
assert (f([3, 3]) == 0).all()
def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize')
......@@ -3724,7 +3723,6 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.xor(x, x), mode=mode)
self.assert_eqs_const(f, 0)
def test_stacktrace(self):
mode = theano.compile.get_default_mode().including(
'local_useless_elemwise_comparison')
......@@ -6310,7 +6308,7 @@ class Test_local_useless_reshape(unittest.TestCase):
# TODO: Check that stack trace is maintained.
# Currently, stack trace gets removed by some other opt.
#assert check_stack_trace(f1, ops_to_check='all')
# assert check_stack_trace(f1, ops_to_check='all')
m2 = m0.excluding('local_useless_reshape')
......@@ -6331,7 +6329,7 @@ class Test_local_useless_reshape(unittest.TestCase):
# TODO: Check that stack trace is maintained.
# Currently, stack trace gets removed by some other opt.
#assert check_stack_trace(f1, ops_to_check='all')
# assert check_stack_trace(f1, ops_to_check='all')
m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2)
......@@ -6385,6 +6383,7 @@ def test_local_reshape_lift():
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check='last')
class Test_lift_transpose_through_dot(unittest.TestCase):
def simple_optimize(self, g):
out2in(opt.local_useless_elemwise).optimize(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论