提交 180c6c90 authored 作者: nouiz's avatar nouiz

Merge pull request #259 from delallea/argmax_dtype

Changed argmax dtype from int32 to int64
......@@ -1907,7 +1907,7 @@ class MaxAndArgmax(Op):
inputs = [x, axis]
broadcastable = [False] * (x.type.ndim - len(axis.data))
outputs = [tensor(x.type.dtype, broadcastable, name='max'),
tensor('int32', broadcastable, name='argmax')]
tensor('int64', broadcastable, name='argmax')]
return Apply(self, inputs, outputs)
def perform(self, node, inp, outs):
......@@ -1916,7 +1916,7 @@ class MaxAndArgmax(Op):
if python_all(axis == range(x.ndim)):
axis = None
max[0] = numpy.asarray(numpy.max(x, axis))
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int32')
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64')
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
......
......@@ -1460,6 +1460,7 @@ class T_max_and_argmax(unittest.TestCase):
v, i = eval_outputs(max_and_argmax(n))
self.assertTrue(v == 5.0)
self.assertTrue(i == 0)
assert i.dtype == 'int64'
v = eval_outputs(max_and_argmax(n)[0].shape)
assert len(v) == 0
v = eval_outputs(max_and_argmax(n)[1].shape)
......@@ -1470,6 +1471,7 @@ class T_max_and_argmax(unittest.TestCase):
v, i = eval_outputs(max_and_argmax(n))
self.assertTrue(v == 3)
self.assertTrue(i == 2)
assert i.dtype == 'int64'
v = eval_outputs(max_and_argmax(n)[0].shape)
assert len(v) == 0
......@@ -1479,6 +1481,7 @@ class T_max_and_argmax(unittest.TestCase):
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None)]:
v, i = eval_outputs(max_and_argmax(n, axis))
assert i.dtype == 'int64'
self.assertTrue(numpy.all(v == numpy.max(data, np_axis)))
self.assertTrue(numpy.all(i == numpy.argmax(data, np_axis)))
v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
......@@ -1515,11 +1518,13 @@ class T_max_and_argmax(unittest.TestCase):
def test2_valid_neg(self):
n = as_tensor_variable(numpy.random.rand(2, 3))
v, i = eval_outputs(max_and_argmax(n, -1))
assert i.dtype == 'int64'
self.assertTrue(v.shape == (2,))
self.assertTrue(i.shape == (2,))
self.assertTrue(numpy.all(v == numpy.max(n.value, -1)))
self.assertTrue(numpy.all(i == numpy.argmax(n.value, -1)))
v, i = eval_outputs(max_and_argmax(n, -2))
assert i.dtype == 'int64'
self.assertTrue(v.shape == (3,))
self.assertTrue(i.shape == (3,))
self.assertTrue(numpy.all(v == numpy.max(n.value, -2)))
......@@ -1535,6 +1540,7 @@ class T_max_and_argmax(unittest.TestCase):
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1, 2], None), ([1, 2, 0], None)]:
v, i = eval_outputs(max_and_argmax(n, axis))
assert i.dtype == 'int64'
self.assertTrue(numpy.all(v == numpy.max(data, np_axis)))
self.assertTrue(numpy.all(i == numpy.argmax(data, np_axis)))
v = eval_outputs(max_and_argmax(n, axis)[0].shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论