提交 dada6d72 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed Argmax for MaxAndArgmax, argmax now only returns a single result

上级 75247c10
......@@ -625,60 +625,60 @@ class T_Cast(unittest.TestCase):
b = f(a)
self.failUnless(numpy.all(b == numpy.arange(10, dtype = type2)))
class T_argmax(unittest.TestCase):
class T_max_and_argmax(unittest.TestCase):
def setUp(self):
numpy.random.seed(123784)
Argmax.debug = 0
MaxAndArgmax.debug = 0
def test0(self):
n = as_tensor(5.0)
v,i = eval_outputs(argmax(n))
v,i = eval_outputs(max_and_argmax(n))
self.failUnless(v == 5.0)
self.failUnless(i == 0)
def test1(self):
n = as_tensor([1,2,3,2,-6])
v,i = eval_outputs(argmax(n))
v,i = eval_outputs(max_and_argmax(n))
self.failUnless(v == 3)
self.failUnless(i == 2)
def test2(self):
n = as_tensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n))
v,i = eval_outputs(max_and_argmax(n))
self.failUnless(numpy.all(i == [0,1]))
def test2b(self):
n = as_tensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n,0))
v,i = eval_outputs(max_and_argmax(n,0))
self.failUnless(numpy.all(i == [0,1,1]))
def test2_invalid(self):
n = as_tensor(numpy.random.rand(2,3))
try:
eval_outputs(argmax(n,3))
eval_outputs(max_and_argmax(n,3))
except ValueError, e:
return
self.fail()
def test2_invalid_neg(self):
n = as_tensor(numpy.random.rand(2,3))
try:
eval_outputs(argmax(n,-3))
eval_outputs(max_and_argmax(n,-3))
except ValueError, e:
return
self.fail()
def test2_valid_neg(self):
n = as_tensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n,-1))
v,i = eval_outputs(max_and_argmax(n,-1))
self.failUnless(v.shape == (2,))
v,i = eval_outputs(argmax(n,-2))
v,i = eval_outputs(max_and_argmax(n,-2))
self.failUnless(v.shape == (3,))
def test3(self):
n = as_tensor(numpy.random.rand(2,3,4))
v,i = eval_outputs(argmax(n,0))
v,i = eval_outputs(max_and_argmax(n,0))
self.failUnless(v.shape == (3,4))
self.failUnless(i.shape == (3,4))
v,i = eval_outputs(argmax(n,1))
v,i = eval_outputs(max_and_argmax(n,1))
self.failUnless(v.shape == (2,4))
self.failUnless(i.shape == (2,4))
v,i = eval_outputs(argmax(n,2))
v,i = eval_outputs(max_and_argmax(n,2))
self.failUnless(v.shape == (2,3))
self.failUnless(i.shape == (2,3))
......
......@@ -484,11 +484,12 @@ class Shape(Op):
return [None]
shape = Shape()
class Argmax(Op):
class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis"""
nin=2 # tensor, axis
nout=2 # max val, max idx
E_axis = 'invalid axis'
def make_node(self, x, axis=None):
x = _as_tensor(x)
if axis is None:
......@@ -502,17 +503,29 @@ class Argmax(Op):
def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.max(x, axis)
max_idx[0] = numpy.argmax(x, axis)
argmax = Argmax()
max_and_argmax = MaxAndArgmax()
def max(x, axis=None):
"""Return indexes of maximum elements obtained by iterating over given axis
Default axis is the last one.
"""
# In python (using MaxAndArgmax.perform()) this leads to an wasteful
# implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine.
return max_and_argmax(x,axis)[0]
def argmax(x, axis=None):
"""Return maximum elements obtained by iterating over given axis
Default axis is the last one.
"""
# In python (using Argmax.perform()) this leads to an wasteful
# In python (using MaxAndArgmax.perform()) this leads to an wasteful
# implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine.
return argmax(x,axis)[0]
return max_and_argmax(x,axis)[1]
##########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论