提交 41177b4c authored 作者: carriepl's avatar carriepl

Merge pull request #3455 from nouiz/bugfix_sum_prod

[BUG] in opt local_op_of_op
...@@ -4577,8 +4577,7 @@ def local_op_of_op(node): ...@@ -4577,8 +4577,7 @@ def local_op_of_op(node):
# doesn't affect other computations. # doesn't affect other computations.
if len(node_inps.clients) == 1: if len(node_inps.clients) == 1:
if (node_inps.owner and if (node_inps.owner and
(isinstance(node_inps.owner.op, T.elemwise.Prod) or (isinstance(node_inps.owner.op, node.op.__class__))):
isinstance(node_inps.owner.op, T.elemwise.Sum))):
# check to see either the inner or outer prod is doing a # check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it # product over all axis, in which case we can remove it
......
...@@ -4891,6 +4891,17 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -4891,6 +4891,17 @@ class T_local_sum_prod(unittest.TestCase):
dd = sorted(dd) dd = sorted(dd)
return data.sum(d).sum(dd[1]).sum(dd[0]) return data.sum(d).sum(dd[1]).sum(dd[0])
def my_sum_prod(data, d, dd):
# This sum when d or dd is a tuple of 2 dimensions.
if not isinstance(d, tuple) and not isinstance(dd, tuple):
return data.sum(d).prod(dd)
if isinstance(d, tuple):
d = sorted(d)
return data.sum(d[1]).sum(d[0]).prod(dd)
else:
dd = sorted(dd)
return data.sum(d).prod(dd[1]).prod(dd[0])
try: try:
for d, dd in dims: for d, dd in dims:
expected = my_sum(input, d, dd) expected = my_sum(input, d, dd)
...@@ -4931,6 +4942,25 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -4931,6 +4942,25 @@ class T_local_sum_prod(unittest.TestCase):
assert numpy.allclose(f(input), input.prod()) assert numpy.allclose(f(input), input.prod())
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
# test sum prod don't get opt.
for d, dd in dims:
expected = my_sum_prod(input, d, dd)
f = theano.function([a], a.sum(d).prod(dd), mode=self.mode)
assert numpy.allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 2
for d, dd in dims[:6]:
f = theano.function([a], a.sum(d).prod(dd).
prod(0), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).prod(dd).prod(0))
assert len(f.maker.fgraph.apply_nodes) == 2
for d in [0, 1, 2]:
f = theano.function([a], a.sum(d).prod(None), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).prod())
assert len(f.maker.fgraph.apply_nodes) == 2
f = theano.function([a], a.sum(None).prod(), mode=self.mode)
assert numpy.allclose(f(input), input.sum())
assert len(f.maker.fgraph.apply_nodes) == 1
def test_local_sum_prod_alloc(self): def test_local_sum_prod_alloc(self):
a = T.dtensor3() a = T.dtensor3()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论