提交 4078794d，作者：lamblin

Merge pull request #1432 from nouiz/fix_sum_sum_crash

Fix sum sum crash
......@@ -2653,8 +2653,13 @@ class MaxAndArgmax(Op):
axis = [axis]
elif isinstance(axis, (tuple, list)):
if len(axis) != 1:
list(axis)
axis = list(axis)
for idx in range(len(axis)):
if axis[idx] < 0:
axis[idx] += x.type.ndim
axis.sort()
if axis == range(-x.type.ndim, 0, 1):
axis = range(x.type.ndim)
assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple"
" axes. the max fct supports it.")
......@@ -2806,10 +2811,17 @@ def makeKeepDims(x, y, axis):
axis = range(x.type.ndim)
elif isinstance(axis, int):
axis = [axis]
newaxis = []
for a in axis:
if not isinstance(a, int):
raise ValueError("keepdims option can be used only with constant axis")
if a < 0:
a += x.type.ndim
newaxis.append(a)
i = 0
new_dims = []
for j, _ in enumerate(x.type.broadcastable):
if j in axis:
if j in newaxis:
new_dims.append('x')
else:
new_dims.append(i)
......
......@@ -3209,7 +3209,7 @@ def local_sum_sum(node):
for i in node.op.axis:
new_i = i
for ii in summed.owner.op.axis:
if i >= ii:
if new_i >= ii:
new_i += 1
assert new_i not in newaxis
newaxis.append(new_i)
......
......@@ -13,9 +13,14 @@ class TestKeepDims:
elif isinstance(axis, int):
axis = [axis]
i = 0
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
new_dims = []
for j, _ in enumerate(x.shape):
if j in axis:
if j in newaxis:
new_dims.append('x')
else:
new_dims.append(i)
......@@ -30,7 +35,9 @@ class TestKeepDims:
# 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis:
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]:
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3], [-1, -2, -3], [0, -1, -2],
[-2, -3, 2]]:
op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
......@@ -50,8 +57,8 @@ class TestKeepDims:
# the following ops can be specified with either a single axis or every
# axis:
for op in ([tensor.argmax, tensor.argmin]):
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]:
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
......@@ -72,7 +79,8 @@ class TestKeepDims:
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std, tensor.all, tensor.any,
tensor.max, tensor.min]):
for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]:
for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
......
......@@ -3473,7 +3473,7 @@ class T_local_sum(unittest.TestCase):
def test_local_sum_all_to_none(self):
    # Presumably exercises the "sum over every axis -> one all-axes Sum"
    # optimization (per the test name) -- TODO confirm against the
    # optimizer source; what the code below shows is that a full sum
    # compiles down to exactly one apply node.
    a = T.tensor3()
    # NOTE(review): both the old (3*3*3) and the new (3*4*5) input lines
    # appear here because the diff's +/- markers were stripped from this
    # page; the second assignment is the one in effect. The non-square
    # shape makes axis mix-ups detectable.
    input = numpy.arange(3 * 3 * 3, dtype=config.floatX).reshape(3, 3, 3)
    input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
    f = theano.function([a], a.sum(), mode=self.mode)
    # The whole reduction must have been fused into a single node.
    assert len(f.maker.fgraph.apply_nodes) == 1
    assert numpy.allclose(f(input), input.sum())
......@@ -3493,36 +3493,50 @@ class T_local_sum(unittest.TestCase):
def test_local_sum_sum(self):
a = T.tensor3()
input = numpy.arange(3 * 3 * 3, dtype=config.floatX).reshape(3, 3, 3)
dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1),
((0, 1), 0), ((1, 2), 0), (0, (0, 1)),
(1, (0, 1)), (2, (0, 1))]
backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False
def my_sum(data, d, dd):
    """Reference value for ``data.sum(d).sum(dd)`` computed one axis at a time.

    At most one of *d* / *dd* is a 2-tuple of axes; that tuple is reduced
    highest axis first, so the remaining (lower) axis index is still valid
    after the first reduction.
    """
    if isinstance(d, tuple):
        axes = sorted(d)
        return data.sum(axes[1]).sum(axes[0]).sum(dd)
    if isinstance(dd, tuple):
        axes = sorted(dd)
        return data.sum(d).sum(axes[1]).sum(axes[0])
    # Neither argument is a tuple: plain chained reduction.
    return data.sum(d).sum(dd)
try:
for d, dd in dims:
expected = my_sum(input, d, dd)
f = theano.function([a], a.sum(d).sum(dd), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd))
assert numpy.allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims:
for d, dd in dims[:6]:
f = theano.function([a], a.sum(d).sum(dd).
sum(0), mode=self.mode)
sum(0), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd).sum(0))
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = theano.function([a], a.sum(d).sum(None), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum())
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = theano.function([a], a.sum(None).sum(), mode=self.mode)
assert numpy.allclose(f(input), input.sum())
assert len(f.maker.fgraph.apply_nodes) == 1
f = theano.function([a], a.sum(None).sum(), mode=self.mode)
assert numpy.allclose(f(input), input.sum())
assert len(f.maker.fgraph.apply_nodes) == 1
finally:
config.warn.sum_sum_bug = backup
def test_local_sum_alloc(self):
a = T.dtensor3()
input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4),
dtype='float64')
dtype='float64')
mode = self.mode.including('specialize').excluding('fusion')
for t_like,n_like,nb_nodes in [(tensor.zeros_like,numpy.zeros_like,(1,3,3,2)),
......@@ -3556,14 +3570,14 @@ class T_local_sum(unittest.TestCase):
try:
for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]:
f = theano.function([a], t_like(a).
sum(d).sum(dd), mode=mode)
sum(d).sum(dd), mode=mode)
assert numpy.allclose(f(input),
n_like(input).sum(d).sum(dd))
n_like(input).sum(d).sum(dd))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op,
T.Sum) for node in topo])
T.Sum) for node in topo])
finally:
config.warn.sum_sum_bug = backup
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论