提交 cbf387b7 authored 作者: Frederic's avatar Frederic

Make MaxAndArgmax accept None as axis and update test to tests more axis.

上级 5ad3c667
......@@ -1885,8 +1885,12 @@ class MaxAndArgmax(Op):
if isinstance(axis, int):
axis = [axis]
elif isinstance(axis, (tuple, list)):
assert len(axis) == 1, ("MaxAndArgmax don't support multiple"
" axis. the max fct support it.")
if len(axis) != 1:
list(axis)
axis.sort()
assert axis == range(x.type.ndim), (
"MaxAndArgmax don't support multiple"
" axis. the max fct support it.")
# we make the axis all positive to make the infer_shape work
# with negative axis
if x.type.ndim > 0 and axis is not None:
......@@ -1901,8 +1905,7 @@ class MaxAndArgmax(Op):
axis = _as_tensor_variable(axis)
inputs = [x, axis]
#TODO: figure things out if axis is a constant
broadcastable = [False] * (x.type.ndim - 1)
broadcastable = [False] * (x.type.ndim - len(axis.data))
outputs = [tensor(x.type.dtype, broadcastable, name='max'),
tensor('int32', broadcastable, name='argmax')]
return Apply(self, inputs, outputs)
......@@ -1920,6 +1923,10 @@ class MaxAndArgmax(Op):
axis = node.inputs[1]
if axis is None:
return [(), ()]
elif len(axis.data) == 0 and node.inputs[0].ndim:
return [(1,), (1,)]
elif python_all(axis.data == range(node.inputs[0].ndim)):
return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data])
return [rval, rval]
......
......@@ -1476,25 +1476,13 @@ class T_max_and_argmax(unittest.TestCase):
def test2(self):
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
v, i = eval_outputs(max_and_argmax(n, -1))
self.assertTrue(numpy.all(v == numpy.max(data, -1)))
self.assertTrue(numpy.all(i == numpy.argmax(data, -1)))
v = eval_outputs(max_and_argmax(n, -1)[0].shape)
assert v == (2)
def test2b(self):
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
v, i = eval_outputs(max_and_argmax(n, 0))
self.assertTrue(numpy.all(v == numpy.max(data, 0)))
self.assertTrue(numpy.all(i == numpy.argmax(data, 0)))
v = eval_outputs(max_and_argmax(n, 0)[0].shape)
assert v == (3)
v = eval_outputs(max_and_argmax(n, 1)[0].shape)
assert v == (2)
#currently not supported
#v = eval_outputs(max_and_argmax(n,[0,1])[0].shape)
#assert v.size==0
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None)]:
v, i = eval_outputs(max_and_argmax(n, axis))
self.assertTrue(numpy.all(v == numpy.max(data, np_axis)))
self.assertTrue(numpy.all(i == numpy.argmax(data, np_axis)))
v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v_shape) == numpy.max(data, np_axis).shape
def test2_invalid(self):
n = as_tensor_variable(numpy.random.rand(2, 3))
......@@ -1542,22 +1530,15 @@ class T_max_and_argmax(unittest.TestCase):
assert v == (3)
def test3(self):
n = as_tensor_variable(numpy.random.rand(2, 3, 4))
v, i = eval_outputs(max_and_argmax(n, 0))
self.assertTrue(v.shape == (3, 4))
self.assertTrue(i.shape == (3, 4))
v, i = eval_outputs(max_and_argmax(n, 1))
self.assertTrue(v.shape == (2, 4))
self.assertTrue(i.shape == (2, 4))
v, i = eval_outputs(max_and_argmax(n, 2))
self.assertTrue(v.shape == (2, 3))
self.assertTrue(i.shape == (2, 3))
v = eval_outputs(max_and_argmax(n, 0)[0].shape)
assert tuple(v) == (3, 4)
v = eval_outputs(max_and_argmax(n, 1)[0].shape)
assert tuple(v) == (2, 4)
v = eval_outputs(max_and_argmax(n, 2)[0].shape)
assert tuple(v) == (2, 3)
data = numpy.random.rand(2, 3, 4)
n = as_tensor_variable(data)
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1, 2], None), ([1, 2, 0], None)]:
v, i = eval_outputs(max_and_argmax(n, axis))
self.assertTrue(numpy.all(v == numpy.max(data, np_axis)))
self.assertTrue(numpy.all(i == numpy.argmax(data, np_axis)))
v = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v) == numpy.max(data, np_axis).shape
def test_grad(self):
data = numpy.random.rand(2, 3)
......@@ -1637,27 +1618,15 @@ class T_argmin_argmax(unittest.TestCase):
assert len(v) == 0
def test2(self):
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
i = eval_outputs(fct(n, -1))
self.assertTrue(numpy.all(i == nfct(data, -1)))
v = eval_outputs(fct(n, -1).shape)
assert v == (2)
def test2b(self):
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
i = eval_outputs(fct(n, 0))
self.assertTrue(numpy.all(i == nfct(data, 0)))
v = eval_outputs(fct(n, 0).shape)
assert v == (3)
v = eval_outputs(fct(n, 1).shape)
assert v == (2)
#currently not supported
#v = eval_outputs(fct(n,[0,1]).shape)
#assert v.size==0
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None)]:
v = eval_outputs(fct(n, axis))
self.assertTrue(numpy.all(v == nfct(data, np_axis)))
v_shape = eval_outputs(fct(n, axis).shape)
assert tuple(v_shape) == nfct(data, np_axis).shape
def test2_invalid(self):
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
......@@ -1705,24 +1674,16 @@ class T_argmin_argmax(unittest.TestCase):
assert v == (3)
def test3(self):
data = numpy.random.rand(2, 3, 4)
n = as_tensor_variable(data)
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
n = as_tensor_variable(numpy.random.rand(2, 3, 4))
i = eval_outputs(fct(n, 0))
self.assertTrue(i.shape == (3, 4))
self.assertTrue(numpy.all(i == nfct(n.value, 0)))
i = eval_outputs(fct(n, 1))
self.assertTrue(i.shape == (2, 4))
self.assertTrue(numpy.all(i == nfct(n.value, 1)))
i = eval_outputs(fct(n, 2))
self.assertTrue(i.shape == (2, 3))
self.assertTrue(numpy.all(i == nfct(n.value, 2)))
v = eval_outputs(fct(n, 0).shape)
assert tuple(v) == (3, 4)
v = eval_outputs(fct(n, 1).shape)
assert tuple(v) == (2, 4)
v = eval_outputs(fct(n, 2).shape)
assert tuple(v) == (2, 3)
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (2, 2),
(None, None), ([0, 1, 2], None),
([1, 0, 2], None)]:
v = eval_outputs(fct(n, axis))
self.assertTrue(numpy.all(v == nfct(data, np_axis)))
v_shape = eval_outputs(fct(n, axis).shape)
assert tuple(v_shape) == nfct(data, np_axis).shape
def test_grad_argmin(self):
data = numpy.random.rand(2, 3)
......@@ -1787,28 +1748,15 @@ class T_min_max(unittest.TestCase):
assert len(v) == 0
def test2(self):
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
for fct, nfct in [(max, numpy.max), (min, numpy.min)]:
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
v = eval_outputs(fct(n, -1))
self.assertTrue(numpy.all(v == nfct(data, -1)))
v = eval_outputs(fct(n, -1).shape)
assert v == (2)
def test2b(self):
for fct, nfct in [(max, numpy.max),(min, numpy.min)]:
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
v = eval_outputs(fct(n, 0))
self.assertTrue(numpy.all(v == nfct(data, 0)))
v = eval_outputs(fct(n, 0).shape)
assert v == (3)
v = eval_outputs(fct(n, 1).shape)
assert v == (2)
v = eval_outputs(fct(n, [0, 1]).shape)
assert v.size == 0
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None)]:
v = eval_outputs(fct(n, axis))
self.assertTrue(numpy.all(v == nfct(data, np_axis)))
v_shape = eval_outputs(fct(n, axis).shape)
assert tuple(v_shape) == nfct(data, np_axis).shape
def test2_invalid(self):
for fct in [max, min]:
......@@ -1856,43 +1804,16 @@ class T_min_max(unittest.TestCase):
assert v == (3)
def test3(self):
data = numpy.random.rand(2, 3, 4)
n = as_tensor_variable(data)
for fct, nfct in [(max, numpy.max), (min, numpy.min)]:
n = as_tensor_variable(numpy.random.rand(2, 3, 4))
v = eval_outputs(fct(n, 0))
self.assertTrue(v.shape == (3, 4))
self.assertTrue(numpy.all(v == nfct(n.value, 0)))
v = eval_outputs(fct(n, 1))
self.assertTrue(v.shape == (2, 4))
self.assertTrue(numpy.all(v == nfct(n.value, 1)))
v = eval_outputs(fct(n, 2))
self.assertTrue(v.shape == (2, 3))
self.assertTrue(numpy.all(v == nfct(n.value, 2)))
v = eval_outputs(fct(n, [0, 1]))
self.assertTrue(v.shape == (4,))
self.assertTrue(numpy.all(v == nfct(nfct(n.value, 1), 0)))
v = eval_outputs(fct(n, [0, 2]))
self.assertTrue(v.shape == (3,))
self.assertTrue(numpy.all(v == nfct(nfct(n.value, 2), 0)))
v = eval_outputs(fct(n, [1, 2]))
self.assertTrue(v.shape == (2,))
self.assertTrue(numpy.all(v == nfct(nfct(n.value, 2), 1)))
v = eval_outputs(fct(n, [0, 1, 2]))
self.assertTrue(v.shape == ())
v = eval_outputs(fct(n, 0).shape)
assert tuple(v) == (3, 4)
v = eval_outputs(fct(n, 1).shape)
assert tuple(v) == (2, 4)
v = eval_outputs(fct(n, 2).shape)
assert tuple(v) == (2, 3)
v = eval_outputs(fct(n, [0, 1]).shape)
self.assertTrue(v == (4,))
v = eval_outputs(fct(n, [0, 2]).shape)
self.assertTrue(v == (3,))
v = eval_outputs(fct(n, [1, 2]).shape)
self.assertTrue(v == (2,))
v = eval_outputs(fct(n, [0, 1, 2]).shape)
self.assertTrue(v.size == 0)
for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (2, 2),
(None, None), ([0, 1, 2], None),
([1, 0, 2], None)]:
v = eval_outputs(fct(n, axis))
self.assertTrue(numpy.all(v == nfct(data, np_axis)))
v_shape = eval_outputs(fct(n, axis).shape)
assert tuple(v_shape) == nfct(data, np_axis).shape
def test_grad_max(self):
data = numpy.random.rand(2, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论