提交 4ae9cf4b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Explicitly calls min, max, etc. with axis=-1 to silence warning.

上级 db5a9c1e
......@@ -890,7 +890,7 @@ def crossentropy_softmax_max_and_argmax_1hot_with_bias(x, b, y_idx, **kwargs):
the appropriate information (i.e. the max probability)?
"""
(xent, softmax) = crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs)
(max_pr, argmax) = tensor.max_and_argmax(softmax)
(max_pr, argmax) = tensor.max_and_argmax(softmax, axis=-1)
return (xent, softmax, max_pr, argmax)
def crossentropy_softmax_max_and_argmax_1hot(x, y_idx, **kwargs):
b = tensor.zeros_like(x[0,:])
......
......@@ -875,10 +875,10 @@ class T_max_and_argmax(unittest.TestCase):
def test2(self):
data = numpy.random.rand(2,3)
n = as_tensor_variable(data)
v,i = eval_outputs(max_and_argmax(n))
v,i = eval_outputs(max_and_argmax(n,-1))
self.failUnless(numpy.all(v == numpy.max(data,-1)))
self.failUnless(numpy.all(i == numpy.argmax(data,-1)))
v = eval_outputs(max_and_argmax(n)[0].shape)
v = eval_outputs(max_and_argmax(n,-1)[0].shape)
assert v==(2)
def test2b(self):
......@@ -977,8 +977,8 @@ class T_max_and_argmax(unittest.TestCase):
#test grad of max
#axis is the last one
utt.verify_grad(lambda v: max_and_argmax(v)[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v)[1], [data])
utt.verify_grad(lambda v: max_and_argmax(v,axis=-1)[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v,axis=-1)[1], [data])
utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[1], [data])
......@@ -1022,9 +1022,9 @@ class T_argmin_argmax(unittest.TestCase):
for fct,nfct in [(argmax,numpy.argmax),(argmin,numpy.argmin)]:
data = numpy.random.rand(2,3)
n = as_tensor_variable(data)
i = eval_outputs(fct(n))
i = eval_outputs(fct(n,-1))
self.failUnless(numpy.all(i == nfct(data,-1)))
v = eval_outputs(fct(n).shape)
v = eval_outputs(fct(n,-1).shape)
assert v==(2)
def test2b(self):
......@@ -1111,7 +1111,7 @@ class T_argmin_argmax(unittest.TestCase):
n = as_tensor_variable(data)
#test grad of argmin
utt.verify_grad(lambda v: argmin(v), [data])
utt.verify_grad(lambda v: argmin(v,axis=-1), [data])
utt.verify_grad(lambda v: argmin(v,axis=[0]), [data])
......@@ -1120,7 +1120,7 @@ class T_argmin_argmax(unittest.TestCase):
utt.verify_grad(lambda v: argmin(v.flatten()), [data])
try:
grad(argmin(n),n)
grad(argmin(n,axis=-1),n)
raise Exception('Expected an error')
except TypeError:
pass
......@@ -1130,7 +1130,7 @@ class T_argmin_argmax(unittest.TestCase):
n = as_tensor_variable(data)
#test grad of argmax
utt.verify_grad(lambda v: argmax(v), [data])
utt.verify_grad(lambda v: argmax(v, axis=-1), [data])
utt.verify_grad(lambda v: argmax(v,axis=[0]), [data])
......@@ -1139,7 +1139,7 @@ class T_argmin_argmax(unittest.TestCase):
utt.verify_grad(lambda v: argmax(v.flatten()), [data])
try:
grad(argmax(n),n)
grad(argmax(n, axis=-1),n)
raise Exception('Expected an error')
except TypeError:
pass
......@@ -1174,7 +1174,7 @@ class T_min_max(unittest.TestCase):
v = eval_outputs(fct(n,-1))
self.failUnless(numpy.all(v == nfct(data,-1)))
v = eval_outputs(fct(n).shape)
v = eval_outputs(fct(n,-1).shape)
assert v==(2)
def test2b(self):
......@@ -1294,7 +1294,7 @@ class T_min_max(unittest.TestCase):
#test grad of max
#axis is the last one
utt.verify_grad(lambda v: max(v), [data])
utt.verify_grad(lambda v: max(v,axis=-1), [data])
utt.verify_grad(lambda v: max(v,axis=[0]), [data])
check_grad_max(data,eval_outputs(grad(max(n,axis=0).sum(),n)),axis=0)
......@@ -1326,7 +1326,7 @@ class T_min_max(unittest.TestCase):
#test grad of min
#axis is the last one
utt.verify_grad(lambda v: min(v), [data])
utt.verify_grad(lambda v: min(v,axis=-1), [data])
utt.verify_grad(lambda v: min(v,axis=[0]), [data])
check_grad_min(data,eval_outputs(grad(min(n,axis=0).sum(),n)),axis=0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论