提交 dada6d72 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed Argmax for MaxAndArgmax, argmax now only returns a single result

上级 75247c10
...@@ -625,60 +625,60 @@ class T_Cast(unittest.TestCase): ...@@ -625,60 +625,60 @@ class T_Cast(unittest.TestCase):
b = f(a) b = f(a)
self.failUnless(numpy.all(b == numpy.arange(10, dtype = type2))) self.failUnless(numpy.all(b == numpy.arange(10, dtype = type2)))
class T_argmax(unittest.TestCase): class T_max_and_argmax(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(123784) numpy.random.seed(123784)
Argmax.debug = 0 MaxAndArgmax.debug = 0
def test0(self): def test0(self):
n = as_tensor(5.0) n = as_tensor(5.0)
v,i = eval_outputs(argmax(n)) v,i = eval_outputs(max_and_argmax(n))
self.failUnless(v == 5.0) self.failUnless(v == 5.0)
self.failUnless(i == 0) self.failUnless(i == 0)
def test1(self): def test1(self):
n = as_tensor([1,2,3,2,-6]) n = as_tensor([1,2,3,2,-6])
v,i = eval_outputs(argmax(n)) v,i = eval_outputs(max_and_argmax(n))
self.failUnless(v == 3) self.failUnless(v == 3)
self.failUnless(i == 2) self.failUnless(i == 2)
def test2(self): def test2(self):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n)) v,i = eval_outputs(max_and_argmax(n))
self.failUnless(numpy.all(i == [0,1])) self.failUnless(numpy.all(i == [0,1]))
def test2b(self): def test2b(self):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n,0)) v,i = eval_outputs(max_and_argmax(n,0))
self.failUnless(numpy.all(i == [0,1,1])) self.failUnless(numpy.all(i == [0,1,1]))
def test2_invalid(self): def test2_invalid(self):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
try: try:
eval_outputs(argmax(n,3)) eval_outputs(max_and_argmax(n,3))
except ValueError, e: except ValueError, e:
return return
self.fail() self.fail()
def test2_invalid_neg(self): def test2_invalid_neg(self):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
try: try:
eval_outputs(argmax(n,-3)) eval_outputs(max_and_argmax(n,-3))
except ValueError, e: except ValueError, e:
return return
self.fail() self.fail()
def test2_valid_neg(self): def test2_valid_neg(self):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n,-1)) v,i = eval_outputs(max_and_argmax(n,-1))
self.failUnless(v.shape == (2,)) self.failUnless(v.shape == (2,))
v,i = eval_outputs(argmax(n,-2)) v,i = eval_outputs(max_and_argmax(n,-2))
self.failUnless(v.shape == (3,)) self.failUnless(v.shape == (3,))
def test3(self): def test3(self):
n = as_tensor(numpy.random.rand(2,3,4)) n = as_tensor(numpy.random.rand(2,3,4))
v,i = eval_outputs(argmax(n,0)) v,i = eval_outputs(max_and_argmax(n,0))
self.failUnless(v.shape == (3,4)) self.failUnless(v.shape == (3,4))
self.failUnless(i.shape == (3,4)) self.failUnless(i.shape == (3,4))
v,i = eval_outputs(argmax(n,1)) v,i = eval_outputs(max_and_argmax(n,1))
self.failUnless(v.shape == (2,4)) self.failUnless(v.shape == (2,4))
self.failUnless(i.shape == (2,4)) self.failUnless(i.shape == (2,4))
v,i = eval_outputs(argmax(n,2)) v,i = eval_outputs(max_and_argmax(n,2))
self.failUnless(v.shape == (2,3)) self.failUnless(v.shape == (2,3))
self.failUnless(i.shape == (2,3)) self.failUnless(i.shape == (2,3))
......
...@@ -484,11 +484,12 @@ class Shape(Op): ...@@ -484,11 +484,12 @@ class Shape(Op):
return [None] return [None]
shape = Shape() shape = Shape()
class Argmax(Op): class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis""" """Calculate the max and argmax over a given axis"""
nin=2 # tensor, axis nin=2 # tensor, axis
nout=2 # max val, max idx nout=2 # max val, max idx
E_axis = 'invalid axis' E_axis = 'invalid axis'
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor(x) x = _as_tensor(x)
if axis is None: if axis is None:
...@@ -502,17 +503,29 @@ class Argmax(Op): ...@@ -502,17 +503,29 @@ class Argmax(Op):
def perform(self, node, (x, axis), (max, max_idx)): def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.max(x, axis) max[0] = numpy.max(x, axis)
max_idx[0] = numpy.argmax(x, axis) max_idx[0] = numpy.argmax(x, axis)
argmax = Argmax() max_and_argmax = MaxAndArgmax()
def max(x, axis=None): def max(x, axis=None):
"""Return indexes of maximum elements obtained by iterating over given axis
Default axis is the last one.
"""
# In python (using MaxAndArgmax.perform()) this leads to an wasteful
# implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine.
return max_and_argmax(x,axis)[0]
def argmax(x, axis=None):
"""Return maximum elements obtained by iterating over given axis """Return maximum elements obtained by iterating over given axis
Default axis is the last one. Default axis is the last one.
""" """
# In python (using Argmax.perform()) this leads to an wasteful # In python (using MaxAndArgmax.perform()) this leads to an wasteful
# implementation that goes through the data twice instead of once # implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine. # but when Argmax.c_impl() is in place, it should be fine.
return argmax(x,axis)[0] return max_and_argmax(x,axis)[1]
########################## ##########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论