提交 cbf387b7 authored 作者: Frederic's avatar Frederic

Make MaxAndArgmax accept None as axis and update test to tests more axis.

上级 5ad3c667
...@@ -1885,7 +1885,11 @@ class MaxAndArgmax(Op): ...@@ -1885,7 +1885,11 @@ class MaxAndArgmax(Op):
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis] axis = [axis]
elif isinstance(axis, (tuple, list)): elif isinstance(axis, (tuple, list)):
assert len(axis) == 1, ("MaxAndArgmax don't support multiple" if len(axis) != 1:
list(axis)
axis.sort()
assert axis == range(x.type.ndim), (
"MaxAndArgmax don't support multiple"
" axis. the max fct support it.") " axis. the max fct support it.")
# we make the axis all positive to make the infer_shape work # we make the axis all positive to make the infer_shape work
# with negative axis # with negative axis
...@@ -1901,8 +1905,7 @@ class MaxAndArgmax(Op): ...@@ -1901,8 +1905,7 @@ class MaxAndArgmax(Op):
axis = _as_tensor_variable(axis) axis = _as_tensor_variable(axis)
inputs = [x, axis] inputs = [x, axis]
#TODO: figure things out if axis is a constant broadcastable = [False] * (x.type.ndim - len(axis.data))
broadcastable = [False] * (x.type.ndim - 1)
outputs = [tensor(x.type.dtype, broadcastable, name='max'), outputs = [tensor(x.type.dtype, broadcastable, name='max'),
tensor('int32', broadcastable, name='argmax')] tensor('int32', broadcastable, name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -1920,6 +1923,10 @@ class MaxAndArgmax(Op): ...@@ -1920,6 +1923,10 @@ class MaxAndArgmax(Op):
axis = node.inputs[1] axis = node.inputs[1]
if axis is None: if axis is None:
return [(), ()] return [(), ()]
elif len(axis.data) == 0 and node.inputs[0].ndim:
return [(1,), (1,)]
elif python_all(axis.data == range(node.inputs[0].ndim)):
return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate( rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data]) node.inputs[0].type.broadcastable) if i != axis.data])
return [rval, rval] return [rval, rval]
......
...@@ -1476,25 +1476,13 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -1476,25 +1476,13 @@ class T_max_and_argmax(unittest.TestCase):
def test2(self): def test2(self):
data = numpy.random.rand(2, 3) data = numpy.random.rand(2, 3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
v, i = eval_outputs(max_and_argmax(n, -1)) for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
self.assertTrue(numpy.all(v == numpy.max(data, -1))) ([0, 1], None), ([1, 0], None)]:
self.assertTrue(numpy.all(i == numpy.argmax(data, -1))) v, i = eval_outputs(max_and_argmax(n, axis))
v = eval_outputs(max_and_argmax(n, -1)[0].shape) self.assertTrue(numpy.all(v == numpy.max(data, np_axis)))
assert v == (2) self.assertTrue(numpy.all(i == numpy.argmax(data, np_axis)))
v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
def test2b(self): assert tuple(v_shape) == numpy.max(data, np_axis).shape
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
v, i = eval_outputs(max_and_argmax(n, 0))
self.assertTrue(numpy.all(v == numpy.max(data, 0)))
self.assertTrue(numpy.all(i == numpy.argmax(data, 0)))
v = eval_outputs(max_and_argmax(n, 0)[0].shape)
assert v == (3)
v = eval_outputs(max_and_argmax(n, 1)[0].shape)
assert v == (2)
#currently not supported
#v = eval_outputs(max_and_argmax(n,[0,1])[0].shape)
#assert v.size==0
def test2_invalid(self): def test2_invalid(self):
n = as_tensor_variable(numpy.random.rand(2, 3)) n = as_tensor_variable(numpy.random.rand(2, 3))
...@@ -1542,22 +1530,15 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -1542,22 +1530,15 @@ class T_max_and_argmax(unittest.TestCase):
assert v == (3) assert v == (3)
def test3(self): def test3(self):
n = as_tensor_variable(numpy.random.rand(2, 3, 4)) data = numpy.random.rand(2, 3, 4)
v, i = eval_outputs(max_and_argmax(n, 0)) n = as_tensor_variable(data)
self.assertTrue(v.shape == (3, 4)) for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
self.assertTrue(i.shape == (3, 4)) ([0, 1, 2], None), ([1, 2, 0], None)]:
v, i = eval_outputs(max_and_argmax(n, 1)) v, i = eval_outputs(max_and_argmax(n, axis))
self.assertTrue(v.shape == (2, 4)) self.assertTrue(numpy.all(v == numpy.max(data, np_axis)))
self.assertTrue(i.shape == (2, 4)) self.assertTrue(numpy.all(i == numpy.argmax(data, np_axis)))
v, i = eval_outputs(max_and_argmax(n, 2)) v = eval_outputs(max_and_argmax(n, axis)[0].shape)
self.assertTrue(v.shape == (2, 3)) assert tuple(v) == numpy.max(data, np_axis).shape
self.assertTrue(i.shape == (2, 3))
v = eval_outputs(max_and_argmax(n, 0)[0].shape)
assert tuple(v) == (3, 4)
v = eval_outputs(max_and_argmax(n, 1)[0].shape)
assert tuple(v) == (2, 4)
v = eval_outputs(max_and_argmax(n, 2)[0].shape)
assert tuple(v) == (2, 3)
def test_grad(self): def test_grad(self):
data = numpy.random.rand(2, 3) data = numpy.random.rand(2, 3)
...@@ -1637,27 +1618,15 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1637,27 +1618,15 @@ class T_argmin_argmax(unittest.TestCase):
assert len(v) == 0 assert len(v) == 0
def test2(self): def test2(self):
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
data = numpy.random.rand(2, 3) data = numpy.random.rand(2, 3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
i = eval_outputs(fct(n, -1))
self.assertTrue(numpy.all(i == nfct(data, -1)))
v = eval_outputs(fct(n, -1).shape)
assert v == (2)
def test2b(self):
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]: for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
data = numpy.random.rand(2, 3) for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
n = as_tensor_variable(data) ([0, 1], None), ([1, 0], None)]:
i = eval_outputs(fct(n, 0)) v = eval_outputs(fct(n, axis))
self.assertTrue(numpy.all(i == nfct(data, 0))) self.assertTrue(numpy.all(v == nfct(data, np_axis)))
v = eval_outputs(fct(n, 0).shape) v_shape = eval_outputs(fct(n, axis).shape)
assert v == (3) assert tuple(v_shape) == nfct(data, np_axis).shape
v = eval_outputs(fct(n, 1).shape)
assert v == (2)
#currently not supported
#v = eval_outputs(fct(n,[0,1]).shape)
#assert v.size==0
def test2_invalid(self): def test2_invalid(self):
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]: for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
...@@ -1705,24 +1674,16 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1705,24 +1674,16 @@ class T_argmin_argmax(unittest.TestCase):
assert v == (3) assert v == (3)
def test3(self): def test3(self):
data = numpy.random.rand(2, 3, 4)
n = as_tensor_variable(data)
for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]: for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
n = as_tensor_variable(numpy.random.rand(2, 3, 4)) for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (2, 2),
i = eval_outputs(fct(n, 0)) (None, None), ([0, 1, 2], None),
self.assertTrue(i.shape == (3, 4)) ([1, 0, 2], None)]:
self.assertTrue(numpy.all(i == nfct(n.value, 0))) v = eval_outputs(fct(n, axis))
i = eval_outputs(fct(n, 1)) self.assertTrue(numpy.all(v == nfct(data, np_axis)))
self.assertTrue(i.shape == (2, 4)) v_shape = eval_outputs(fct(n, axis).shape)
self.assertTrue(numpy.all(i == nfct(n.value, 1))) assert tuple(v_shape) == nfct(data, np_axis).shape
i = eval_outputs(fct(n, 2))
self.assertTrue(i.shape == (2, 3))
self.assertTrue(numpy.all(i == nfct(n.value, 2)))
v = eval_outputs(fct(n, 0).shape)
assert tuple(v) == (3, 4)
v = eval_outputs(fct(n, 1).shape)
assert tuple(v) == (2, 4)
v = eval_outputs(fct(n, 2).shape)
assert tuple(v) == (2, 3)
def test_grad_argmin(self): def test_grad_argmin(self):
data = numpy.random.rand(2, 3) data = numpy.random.rand(2, 3)
...@@ -1787,28 +1748,15 @@ class T_min_max(unittest.TestCase): ...@@ -1787,28 +1748,15 @@ class T_min_max(unittest.TestCase):
assert len(v) == 0 assert len(v) == 0
def test2(self): def test2(self):
for fct, nfct in [(max, numpy.max), (min, numpy.min)]:
data = numpy.random.rand(2, 3) data = numpy.random.rand(2, 3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
v = eval_outputs(fct(n, -1)) for fct, nfct in [(max, numpy.max), (min, numpy.min)]:
self.assertTrue(numpy.all(v == nfct(data, -1))) for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (None, None),
([0, 1], None), ([1, 0], None)]:
v = eval_outputs(fct(n, -1).shape) v = eval_outputs(fct(n, axis))
assert v == (2) self.assertTrue(numpy.all(v == nfct(data, np_axis)))
v_shape = eval_outputs(fct(n, axis).shape)
def test2b(self): assert tuple(v_shape) == nfct(data, np_axis).shape
for fct, nfct in [(max, numpy.max),(min, numpy.min)]:
data = numpy.random.rand(2, 3)
n = as_tensor_variable(data)
v = eval_outputs(fct(n, 0))
self.assertTrue(numpy.all(v == nfct(data, 0)))
v = eval_outputs(fct(n, 0).shape)
assert v == (3)
v = eval_outputs(fct(n, 1).shape)
assert v == (2)
v = eval_outputs(fct(n, [0, 1]).shape)
assert v.size == 0
def test2_invalid(self): def test2_invalid(self):
for fct in [max, min]: for fct in [max, min]:
...@@ -1856,43 +1804,16 @@ class T_min_max(unittest.TestCase): ...@@ -1856,43 +1804,16 @@ class T_min_max(unittest.TestCase):
assert v == (3) assert v == (3)
def test3(self): def test3(self):
data = numpy.random.rand(2, 3, 4)
n = as_tensor_variable(data)
for fct, nfct in [(max, numpy.max), (min, numpy.min)]: for fct, nfct in [(max, numpy.max), (min, numpy.min)]:
n = as_tensor_variable(numpy.random.rand(2, 3, 4)) for (axis, np_axis) in [(-1, -1), (0, 0), (1, 1), (2, 2),
v = eval_outputs(fct(n, 0)) (None, None), ([0, 1, 2], None),
self.assertTrue(v.shape == (3, 4)) ([1, 0, 2], None)]:
self.assertTrue(numpy.all(v == nfct(n.value, 0))) v = eval_outputs(fct(n, axis))
v = eval_outputs(fct(n, 1)) self.assertTrue(numpy.all(v == nfct(data, np_axis)))
self.assertTrue(v.shape == (2, 4)) v_shape = eval_outputs(fct(n, axis).shape)
self.assertTrue(numpy.all(v == nfct(n.value, 1))) assert tuple(v_shape) == nfct(data, np_axis).shape
v = eval_outputs(fct(n, 2))
self.assertTrue(v.shape == (2, 3))
self.assertTrue(numpy.all(v == nfct(n.value, 2)))
v = eval_outputs(fct(n, [0, 1]))
self.assertTrue(v.shape == (4,))
self.assertTrue(numpy.all(v == nfct(nfct(n.value, 1), 0)))
v = eval_outputs(fct(n, [0, 2]))
self.assertTrue(v.shape == (3,))
self.assertTrue(numpy.all(v == nfct(nfct(n.value, 2), 0)))
v = eval_outputs(fct(n, [1, 2]))
self.assertTrue(v.shape == (2,))
self.assertTrue(numpy.all(v == nfct(nfct(n.value, 2), 1)))
v = eval_outputs(fct(n, [0, 1, 2]))
self.assertTrue(v.shape == ())
v = eval_outputs(fct(n, 0).shape)
assert tuple(v) == (3, 4)
v = eval_outputs(fct(n, 1).shape)
assert tuple(v) == (2, 4)
v = eval_outputs(fct(n, 2).shape)
assert tuple(v) == (2, 3)
v = eval_outputs(fct(n, [0, 1]).shape)
self.assertTrue(v == (4,))
v = eval_outputs(fct(n, [0, 2]).shape)
self.assertTrue(v == (3,))
v = eval_outputs(fct(n, [1, 2]).shape)
self.assertTrue(v == (2,))
v = eval_outputs(fct(n, [0, 1, 2]).shape)
self.assertTrue(v.size == 0)
def test_grad_max(self): def test_grad_max(self):
data = numpy.random.rand(2, 3) data = numpy.random.rand(2, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论