提交 92ec8cb2 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

More tests of argmax gradient

Now also testing gradient with axis=None (failing as of this commit).
上级 bd3cac42
...@@ -1546,9 +1546,9 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -1546,9 +1546,9 @@ class T_max_and_argmax(unittest.TestCase):
def check_grad_max(data, max_grad_data, axis=None): def check_grad_max(data, max_grad_data, axis=None):
""" """
Why this is needed? verify_grad is not enought? Why this is needed? verify_grad is not enough?
""" """
#This work only for axis in [0,None] # This works only for axis in [0, None].
assert axis in [0, None] assert axis in [0, None]
z = numpy.zeros_like(data) z = numpy.zeros_like(data)
z = z.flatten() z = z.flatten()
...@@ -1563,29 +1563,22 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -1563,29 +1563,22 @@ class T_max_and_argmax(unittest.TestCase):
z = z.reshape(data.shape) z = z.reshape(data.shape)
assert numpy.all(max_grad_data == z) assert numpy.all(max_grad_data == z)
#test grad of max for axis in (-1, 0, 1, None):
#axis is the last one for j in xrange(2):
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v, axis=axis)[j],
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[1], [data]) [data])
if axis != 1:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v.flatten(),
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[1], [data]) axis=axis)[j],
check_grad_max(data, eval_outputs(grad( [data])
max_and_argmax(n, axis=0)[0].sum(), n)), axis=0) if axis in (0, None):
check_grad_max(data, eval_outputs(grad(
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[0], [data]) max_and_argmax(n, axis=axis)[0].sum(), n)), axis=axis)
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[1], [data]) check_grad_max(data, eval_outputs(grad(
#check_grad_max(data,eval_outputs(grad( max_and_argmax(n.flatten())[0], n)))
# max_and_argmax(n,axis=1)[0],n)),axis=1)
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[1], [data])
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n.flatten())[0], n)))
# Test 4d inner dimensions # Test 4d inner dimensions
data = numpy.random.rand(2, 3, 4, 5) data = numpy.random.rand(2, 3, 4, 5)
n = as_tensor_variable(data)
for i in [0, 1, 2, 3]: for i in [0, 1, 2, 3]:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论