提交 92ec8cb2 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

More tests of argmax gradient

Now also testing gradient with axis=None (failing as of this commit).
上级 bd3cac42
......@@ -1546,9 +1546,9 @@ class T_max_and_argmax(unittest.TestCase):
def check_grad_max(data, max_grad_data, axis=None):
"""
Why this is needed? verify_grad is not enought?
Why this is needed? verify_grad is not enough?
"""
#This work only for axis in [0,None]
# This works only for axis in [0, None].
assert axis in [0, None]
z = numpy.zeros_like(data)
z = z.flatten()
......@@ -1563,29 +1563,22 @@ class T_max_and_argmax(unittest.TestCase):
z = z.reshape(data.shape)
assert numpy.all(max_grad_data == z)
#test grad of max
#axis is the last one
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[1], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[1], [data])
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n, axis=0)[0].sum(), n)), axis=0)
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[1], [data])
#check_grad_max(data,eval_outputs(grad(
# max_and_argmax(n,axis=1)[0],n)),axis=1)
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[1], [data])
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n.flatten())[0], n)))
for axis in (-1, 0, 1, None):
for j in xrange(2):
utt.verify_grad(lambda v: max_and_argmax(v, axis=axis)[j],
[data])
if axis != 1:
utt.verify_grad(lambda v: max_and_argmax(v.flatten(),
axis=axis)[j],
[data])
if axis in (0, None):
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n, axis=axis)[0].sum(), n)), axis=axis)
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n.flatten())[0], n)))
# Test 4d inner dimensions
data = numpy.random.rand(2, 3, 4, 5)
n = as_tensor_variable(data)
for i in [0, 1, 2, 3]:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论