提交 4078794d 作者: lamblin

Merge pull request #1432 from nouiz/fix_sum_sum_crash

Fix sum sum crash
...@@ -2653,8 +2653,13 @@ class MaxAndArgmax(Op): ...@@ -2653,8 +2653,13 @@ class MaxAndArgmax(Op):
axis = [axis] axis = [axis]
elif isinstance(axis, (tuple, list)): elif isinstance(axis, (tuple, list)):
if len(axis) != 1: if len(axis) != 1:
list(axis) axis = list(axis)
for idx in range(len(axis)):
if axis[idx] < 0:
axis[idx] += x.type.ndim
axis.sort() axis.sort()
if axis == range(-x.type.ndim, 0, 1):
axis = range(x.type.ndim)
assert axis == range(x.type.ndim), ( assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple" "MaxAndArgmax does not support multiple"
" axes. the max fct supports it.") " axes. the max fct supports it.")
...@@ -2806,10 +2811,17 @@ def makeKeepDims(x, y, axis): ...@@ -2806,10 +2811,17 @@ def makeKeepDims(x, y, axis):
axis = range(x.type.ndim) axis = range(x.type.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
newaxis = []
for a in axis:
if not isinstance(a, int):
raise ValueError("keepdims option can be used only with constant axis")
if a < 0:
a += x.type.ndim
newaxis.append(a)
i = 0 i = 0
new_dims = [] new_dims = []
for j, _ in enumerate(x.type.broadcastable): for j, _ in enumerate(x.type.broadcastable):
if j in axis: if j in newaxis:
new_dims.append('x') new_dims.append('x')
else: else:
new_dims.append(i) new_dims.append(i)
......
...@@ -3209,7 +3209,7 @@ def local_sum_sum(node): ...@@ -3209,7 +3209,7 @@ def local_sum_sum(node):
for i in node.op.axis: for i in node.op.axis:
new_i = i new_i = i
for ii in summed.owner.op.axis: for ii in summed.owner.op.axis:
if i >= ii: if new_i >= ii:
new_i += 1 new_i += 1
assert new_i not in newaxis assert new_i not in newaxis
newaxis.append(new_i) newaxis.append(new_i)
......
...@@ -13,9 +13,14 @@ class TestKeepDims: ...@@ -13,9 +13,14 @@ class TestKeepDims:
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
i = 0 i = 0
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
new_dims = [] new_dims = []
for j, _ in enumerate(x.shape): for j, _ in enumerate(x.shape):
if j in axis: if j in newaxis:
new_dims.append('x') new_dims.append('x')
else: else:
new_dims.append(i) new_dims.append(i)
...@@ -30,7 +35,9 @@ class TestKeepDims: ...@@ -30,7 +35,9 @@ class TestKeepDims:
# 'max_and_argmax' has two outputs and can be specified with either # 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis: # a single or every axis:
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]: for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3], [-1, -2, -3], [0, -1, -2],
[-2, -3, 2]]:
op = tensor.max_and_argmax op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0]) keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
...@@ -50,8 +57,8 @@ class TestKeepDims: ...@@ -50,8 +57,8 @@ class TestKeepDims:
# the following ops can be specified with either a single axis or every # the following ops can be specified with either a single axis or every
# axis: # axis:
for op in ([tensor.argmax, tensor.argmin]): for op in ([tensor.argmax, tensor.argmin]):
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]: [-1], [-2], [-3], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
...@@ -72,7 +79,8 @@ class TestKeepDims: ...@@ -72,7 +79,8 @@ class TestKeepDims:
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var, for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std, tensor.all, tensor.any, tensor.std, tensor.all, tensor.any,
tensor.max, tensor.min]): tensor.max, tensor.min]):
for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]: for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
......
...@@ -3473,7 +3473,7 @@ class T_local_sum(unittest.TestCase): ...@@ -3473,7 +3473,7 @@ class T_local_sum(unittest.TestCase):
def test_local_sum_all_to_none(self): def test_local_sum_all_to_none(self):
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 3 * 3, dtype=config.floatX).reshape(3, 3, 3) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
f = theano.function([a], a.sum(), mode=self.mode) f = theano.function([a], a.sum(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum()) assert numpy.allclose(f(input), input.sum())
...@@ -3493,36 +3493,50 @@ class T_local_sum(unittest.TestCase): ...@@ -3493,36 +3493,50 @@ class T_local_sum(unittest.TestCase):
def test_local_sum_sum(self): def test_local_sum_sum(self):
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 3 * 3, dtype=config.floatX).reshape(3, 3, 3) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)] dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1),
((0, 1), 0), ((1, 2), 0), (0, (0, 1)),
(1, (0, 1)), (2, (0, 1))]
backup = config.warn.sum_sum_bug backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False config.warn.sum_sum_bug = False
def my_sum(data, d, dd):
# This sum when d or dd is a tuple of 2 dimensions.
if not isinstance(d, tuple) and not isinstance(dd, tuple):
return data.sum(d).sum(dd)
if isinstance(d, tuple):
d = sorted(d)
return data.sum(d[1]).sum(d[0]).sum(dd)
else:
dd = sorted(dd)
return data.sum(d).sum(dd[1]).sum(dd[0])
try: try:
for d, dd in dims: for d, dd in dims:
expected = my_sum(input, d, dd)
f = theano.function([a], a.sum(d).sum(dd), mode=self.mode) f = theano.function([a], a.sum(d).sum(dd), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd)) assert numpy.allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims: for d, dd in dims[:6]:
f = theano.function([a], a.sum(d).sum(dd). f = theano.function([a], a.sum(d).sum(dd).
sum(0), mode=self.mode) sum(0), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd).sum(0)) assert numpy.allclose(f(input), input.sum(d).sum(dd).sum(0))
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]: for d in [0, 1, 2]:
f = theano.function([a], a.sum(d).sum(None), mode=self.mode) f = theano.function([a], a.sum(d).sum(None), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum()) assert numpy.allclose(f(input), input.sum(d).sum())
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]: f = theano.function([a], a.sum(None).sum(), mode=self.mode)
f = theano.function([a], a.sum(None).sum(), mode=self.mode) assert numpy.allclose(f(input), input.sum())
assert numpy.allclose(f(input), input.sum()) assert len(f.maker.fgraph.apply_nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
def test_local_sum_alloc(self): def test_local_sum_alloc(self):
a = T.dtensor3() a = T.dtensor3()
input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4), input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4),
dtype='float64') dtype='float64')
mode = self.mode.including('specialize').excluding('fusion') mode = self.mode.including('specialize').excluding('fusion')
for t_like,n_like,nb_nodes in [(tensor.zeros_like,numpy.zeros_like,(1,3,3,2)), for t_like,n_like,nb_nodes in [(tensor.zeros_like,numpy.zeros_like,(1,3,3,2)),
...@@ -3556,14 +3570,14 @@ class T_local_sum(unittest.TestCase): ...@@ -3556,14 +3570,14 @@ class T_local_sum(unittest.TestCase):
try: try:
for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]: for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]:
f = theano.function([a], t_like(a). f = theano.function([a], t_like(a).
sum(d).sum(dd), mode=mode) sum(d).sum(dd), mode=mode)
assert numpy.allclose(f(input), assert numpy.allclose(f(input),
n_like(input).sum(d).sum(dd)) n_like(input).sum(d).sum(dd))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3] assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3]
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, assert not any([isinstance(node.op,
T.Sum) for node in topo]) T.Sum) for node in topo])
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论