提交 9dac94ad authored 作者: Frederic Bastien's avatar Frederic Bastien

implemented MaxAndArgMax.infer_shape and test. I discovered that we can't make a max on a matrix!

上级 64178ac8
......@@ -1324,6 +1324,15 @@ class MaxAndArgmax(Op):
x = _as_tensor_variable(x)
if axis is None:
axis = x.type.ndim - 1
if isinstance(axis,int):
axis = [axis]
#we make the axis all positive to make the infer_shape work with negative axis
if x.type.ndim>0:
for id,a in enumerate(axis):
if a<0:
if -a>x.type.ndim:
raise ValueError('axis out of range')
axis[id]=x.type.ndim+a
axis = _as_tensor_variable(axis)
inputs = [x, axis]
#TODO: figure things out if axis is a constant
......@@ -1334,6 +1343,14 @@ class MaxAndArgmax(Op):
def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.asarray(numpy.max(x, axis))
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int32')
def infer_shape(self, node, (ishape,axis_shape)):
axis=node.inputs[1]
if axis is None:
return [(),()]
rval = tuple([ishape[i] for (i,b) in enumerate(node.inputs[0].type.broadcastable) if i !=axis.data])
return [rval,rval]
def grad(self, (x, axis), (g_max, g_max_idx)):
# @warning: This only works if axis is 0, else the max is
# broadcasted wrong in the call to eq.
......
......@@ -615,12 +615,16 @@ class T_max_and_argmax(unittest.TestCase):
v,i = eval_outputs(max_and_argmax(n))
self.failUnless(v == 5.0)
self.failUnless(i == 0)
v = eval_outputs(max_and_argmax(n)[0].shape)
assert len(v)==0
def test1(self):
n = as_tensor_variable([1,2,3,2,-6])
v,i = eval_outputs(max_and_argmax(n))
self.failUnless(v == 3)
self.failUnless(i == 2)
v = eval_outputs(max_and_argmax(n)[0].shape)
assert len(v)==0
def test2(self):
data = numpy.random.rand(2,3)
......@@ -628,12 +632,22 @@ class T_max_and_argmax(unittest.TestCase):
v,i = eval_outputs(max_and_argmax(n))
self.failUnless(numpy.all(v == numpy.max(data,-1)))
self.failUnless(numpy.all(i == numpy.argmax(data,-1)))
v = eval_outputs(max_and_argmax(n)[0].shape)
assert v==(2)
def test2b(self):
data = numpy.random.rand(2,3)
n = as_tensor_variable(data)
v,i = eval_outputs(max_and_argmax(n,0))
self.failUnless(numpy.all(v == numpy.max(data,0)))
self.failUnless(numpy.all(i == numpy.argmax(data,0)))
v = eval_outputs(max_and_argmax(n,0)[0].shape)
assert v==(3)
v = eval_outputs(max_and_argmax(n,1)[0].shape)
assert v==(2)
# v = eval_outputs(max_and_argmax(n,[0,1])[0].shape)
# assert v==()
def test2_invalid(self):
n = as_tensor_variable(numpy.random.rand(2,3))
old_stderr = sys.stderr
......@@ -662,6 +676,11 @@ class T_max_and_argmax(unittest.TestCase):
self.failUnless(v.shape == (2,))
v,i = eval_outputs(max_and_argmax(n,-2))
self.failUnless(v.shape == (3,))
v = eval_outputs(max_and_argmax(n,-1)[0].shape)
assert v==(2)
v = eval_outputs(max_and_argmax(n,-2)[0].shape)
assert v==(3)
def test3(self):
n = as_tensor_variable(numpy.random.rand(2,3,4))
v,i = eval_outputs(max_and_argmax(n,0))
......@@ -673,6 +692,12 @@ class T_max_and_argmax(unittest.TestCase):
v,i = eval_outputs(max_and_argmax(n,2))
self.failUnless(v.shape == (2,3))
self.failUnless(i.shape == (2,3))
v = eval_outputs(max_and_argmax(n,0)[0].shape)
assert tuple(v)==(3,4)
v = eval_outputs(max_and_argmax(n,1)[0].shape)
assert tuple(v)==(2,4)
v = eval_outputs(max_and_argmax(n,2)[0].shape)
assert tuple(v)==(2,3)
class T_subtensor(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论