Commit 4ae9cf4b, authored by Pascal Lamblin

Explicitly calls min, max, etc. with axis=-1 to silence warning.

Parent commit: db5a9c1e
...@@ -890,7 +890,7 @@ def crossentropy_softmax_max_and_argmax_1hot_with_bias(x, b, y_idx, **kwargs): ...@@ -890,7 +890,7 @@ def crossentropy_softmax_max_and_argmax_1hot_with_bias(x, b, y_idx, **kwargs):
the appropriate information (i.e. the max probability)? the appropriate information (i.e. the max probability)?
""" """
(xent, softmax) = crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs) (xent, softmax) = crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs)
(max_pr, argmax) = tensor.max_and_argmax(softmax) (max_pr, argmax) = tensor.max_and_argmax(softmax, axis=-1)
return (xent, softmax, max_pr, argmax) return (xent, softmax, max_pr, argmax)
def crossentropy_softmax_max_and_argmax_1hot(x, y_idx, **kwargs): def crossentropy_softmax_max_and_argmax_1hot(x, y_idx, **kwargs):
b = tensor.zeros_like(x[0,:]) b = tensor.zeros_like(x[0,:])
......
...@@ -875,10 +875,10 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -875,10 +875,10 @@ class T_max_and_argmax(unittest.TestCase):
def test2(self): def test2(self):
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
v,i = eval_outputs(max_and_argmax(n)) v,i = eval_outputs(max_and_argmax(n,-1))
self.failUnless(numpy.all(v == numpy.max(data,-1))) self.failUnless(numpy.all(v == numpy.max(data,-1)))
self.failUnless(numpy.all(i == numpy.argmax(data,-1))) self.failUnless(numpy.all(i == numpy.argmax(data,-1)))
v = eval_outputs(max_and_argmax(n)[0].shape) v = eval_outputs(max_and_argmax(n,-1)[0].shape)
assert v==(2) assert v==(2)
def test2b(self): def test2b(self):
...@@ -977,8 +977,8 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -977,8 +977,8 @@ class T_max_and_argmax(unittest.TestCase):
#test grad of max #test grad of max
#axis is the last one #axis is the last one
utt.verify_grad(lambda v: max_and_argmax(v)[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v,axis=-1)[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v)[1], [data]) utt.verify_grad(lambda v: max_and_argmax(v,axis=-1)[1], [data])
utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[1], [data]) utt.verify_grad(lambda v: max_and_argmax(v,axis=[0])[1], [data])
...@@ -1022,9 +1022,9 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1022,9 +1022,9 @@ class T_argmin_argmax(unittest.TestCase):
for fct,nfct in [(argmax,numpy.argmax),(argmin,numpy.argmin)]: for fct,nfct in [(argmax,numpy.argmax),(argmin,numpy.argmin)]:
data = numpy.random.rand(2,3) data = numpy.random.rand(2,3)
n = as_tensor_variable(data) n = as_tensor_variable(data)
i = eval_outputs(fct(n)) i = eval_outputs(fct(n,-1))
self.failUnless(numpy.all(i == nfct(data,-1))) self.failUnless(numpy.all(i == nfct(data,-1)))
v = eval_outputs(fct(n).shape) v = eval_outputs(fct(n,-1).shape)
assert v==(2) assert v==(2)
def test2b(self): def test2b(self):
...@@ -1111,7 +1111,7 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1111,7 +1111,7 @@ class T_argmin_argmax(unittest.TestCase):
n = as_tensor_variable(data) n = as_tensor_variable(data)
#test grad of argmin #test grad of argmin
utt.verify_grad(lambda v: argmin(v), [data]) utt.verify_grad(lambda v: argmin(v,axis=-1), [data])
utt.verify_grad(lambda v: argmin(v,axis=[0]), [data]) utt.verify_grad(lambda v: argmin(v,axis=[0]), [data])
...@@ -1120,7 +1120,7 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1120,7 +1120,7 @@ class T_argmin_argmax(unittest.TestCase):
utt.verify_grad(lambda v: argmin(v.flatten()), [data]) utt.verify_grad(lambda v: argmin(v.flatten()), [data])
try: try:
grad(argmin(n),n) grad(argmin(n,axis=-1),n)
raise Exception('Expected an error') raise Exception('Expected an error')
except TypeError: except TypeError:
pass pass
...@@ -1130,7 +1130,7 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1130,7 +1130,7 @@ class T_argmin_argmax(unittest.TestCase):
n = as_tensor_variable(data) n = as_tensor_variable(data)
#test grad of argmax #test grad of argmax
utt.verify_grad(lambda v: argmax(v), [data]) utt.verify_grad(lambda v: argmax(v, axis=-1), [data])
utt.verify_grad(lambda v: argmax(v,axis=[0]), [data]) utt.verify_grad(lambda v: argmax(v,axis=[0]), [data])
...@@ -1139,7 +1139,7 @@ class T_argmin_argmax(unittest.TestCase): ...@@ -1139,7 +1139,7 @@ class T_argmin_argmax(unittest.TestCase):
utt.verify_grad(lambda v: argmax(v.flatten()), [data]) utt.verify_grad(lambda v: argmax(v.flatten()), [data])
try: try:
grad(argmax(n),n) grad(argmax(n, axis=-1),n)
raise Exception('Expected an error') raise Exception('Expected an error')
except TypeError: except TypeError:
pass pass
...@@ -1174,7 +1174,7 @@ class T_min_max(unittest.TestCase): ...@@ -1174,7 +1174,7 @@ class T_min_max(unittest.TestCase):
v = eval_outputs(fct(n,-1)) v = eval_outputs(fct(n,-1))
self.failUnless(numpy.all(v == nfct(data,-1))) self.failUnless(numpy.all(v == nfct(data,-1)))
v = eval_outputs(fct(n).shape) v = eval_outputs(fct(n,-1).shape)
assert v==(2) assert v==(2)
def test2b(self): def test2b(self):
...@@ -1294,7 +1294,7 @@ class T_min_max(unittest.TestCase): ...@@ -1294,7 +1294,7 @@ class T_min_max(unittest.TestCase):
#test grad of max #test grad of max
#axis is the last one #axis is the last one
utt.verify_grad(lambda v: max(v), [data]) utt.verify_grad(lambda v: max(v,axis=-1), [data])
utt.verify_grad(lambda v: max(v,axis=[0]), [data]) utt.verify_grad(lambda v: max(v,axis=[0]), [data])
check_grad_max(data,eval_outputs(grad(max(n,axis=0).sum(),n)),axis=0) check_grad_max(data,eval_outputs(grad(max(n,axis=0).sum(),n)),axis=0)
...@@ -1326,7 +1326,7 @@ class T_min_max(unittest.TestCase): ...@@ -1326,7 +1326,7 @@ class T_min_max(unittest.TestCase):
#test grad of min #test grad of min
#axis is the last one #axis is the last one
utt.verify_grad(lambda v: min(v), [data]) utt.verify_grad(lambda v: min(v,axis=-1), [data])
utt.verify_grad(lambda v: min(v,axis=[0]), [data]) utt.verify_grad(lambda v: min(v,axis=[0]), [data])
check_grad_min(data,eval_outputs(grad(min(n,axis=0).sum(),n)),axis=0) check_grad_min(data,eval_outputs(grad(min(n,axis=0).sum(),n)),axis=0)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment