提交 10563359 authored 作者: Frederic's avatar Frederic

fix opt crash when taking multiple consecutive sum.

上级 6469c740
......@@ -3209,7 +3209,7 @@ def local_sum_sum(node):
for i in node.op.axis:
new_i = i
for ii in summed.owner.op.axis:
if i >= ii:
if new_i >= ii:
new_i += 1
assert new_i not in newaxis
newaxis.append(new_i)
......
......@@ -3493,17 +3493,32 @@ class T_local_sum(unittest.TestCase):
def test_local_sum_sum(self):
a = T.tensor3()
input = numpy.arange(3 * 3 * 3, dtype=config.floatX).reshape(3, 3, 3)
dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1),
((0, 1), 0), ((1, 2), 0), (0, (0, 1)),
(1, (0, 1)), (2, (0, 1))]
backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False
def my_sum(data, d, dd):
# This sum when d or dd is a tuple of 2 dimensions.
if not isinstance(d, tuple) and not isinstance(dd, tuple):
return data.sum(d).sum(dd)
if isinstance(d, tuple):
d = sorted(d)
return data.sum(d[1]).sum(d[0]).sum(dd)
else:
dd = sorted(dd)
return data.sum(d).sum(dd[1]).sum(dd[0])
try:
for d, dd in dims:
expected = my_sum(input, d, dd)
f = theano.function([a], a.sum(d).sum(dd), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd))
assert numpy.allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims:
for d, dd in dims[:6]:
f = theano.function([a], a.sum(d).sum(dd).
sum(0), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd).sum(0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论