提交 24d116ac authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5089 from lamblin/fix_maxargmax_zero_shape

Fix maxargmax zero shape
...@@ -1421,8 +1421,10 @@ class MaxAndArgmax(Op): ...@@ -1421,8 +1421,10 @@ class MaxAndArgmax(Op):
dtype='int64') dtype='int64')
# Not-reduced axes in front # Not-reduced axes in front
transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes))) transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes)))
reshaped_x = transposed_x.reshape(transposed_x.shape[:len(keep_axes)] + kept_shape = transposed_x.shape[:len(keep_axes)]
(-1,)) reduced_shape = transposed_x.shape[len(keep_axes):]
new_shape = kept_shape + (numpy.prod(reduced_shape),)
reshaped_x = transposed_x.reshape(new_shape)
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1),
dtype='int64') dtype='int64')
......
...@@ -3083,6 +3083,15 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -3083,6 +3083,15 @@ class T_max_and_argmax(unittest.TestCase):
v = eval_outputs(max_and_argmax(x, [1, -1])[0].shape) v = eval_outputs(max_and_argmax(x, [1, -1])[0].shape)
assert tuple(v) == numpy.max(data, (1, -1)).shape assert tuple(v) == numpy.max(data, (1, -1)).shape
def test_zero_shape(self):
x = tensor.matrix()
m, i = max_and_argmax(x, axis=1)
f = theano.function([x], [m, i])
xv = numpy.zeros((0, 4), dtype=floatX)
mv, iv = f(xv)
assert mv.shape == (0,)
assert iv.shape == (0,)
class T_argmin_argmax(unittest.TestCase): class T_argmin_argmax(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论