提交 f2ed8c2f authored 作者: nouiz's avatar nouiz

Merge pull request #265 from delallea/fix_argmax_grad

Fix argmax grad
......@@ -1964,8 +1964,11 @@ class MaxAndArgmax(Op):
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
out_dim = 0
for i in range(inp[0].ndim):
if i == axis.data:
if python_all(axis.data == range(x.ndim)):
# We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i == axis.data:
pattern.append('x')
else:
pattern.append(out_dim)
......
......@@ -1546,9 +1546,9 @@ class T_max_and_argmax(unittest.TestCase):
def check_grad_max(data, max_grad_data, axis=None):
"""
Why this is needed? verify_grad is not enought?
Why this is needed? verify_grad is not enough?
"""
#This work only for axis in [0,None]
# This works only for axis in [0, None].
assert axis in [0, None]
z = numpy.zeros_like(data)
z = z.flatten()
......@@ -1563,29 +1563,22 @@ class T_max_and_argmax(unittest.TestCase):
z = z.reshape(data.shape)
assert numpy.all(max_grad_data == z)
#test grad of max
#axis is the last one
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[1], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[1], [data])
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n, axis=0)[0].sum(), n)), axis=0)
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[1], [data])
#check_grad_max(data,eval_outputs(grad(
# max_and_argmax(n,axis=1)[0],n)),axis=1)
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[1], [data])
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n.flatten())[0], n)))
for axis in (-1, 0, 1, None):
for j in xrange(2):
utt.verify_grad(lambda v: max_and_argmax(v, axis=axis)[j],
[data])
if axis != 1:
utt.verify_grad(lambda v: max_and_argmax(v.flatten(),
axis=axis)[j],
[data])
if axis in (0, None):
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n, axis=axis)[0].sum(), n)), axis=axis)
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n.flatten())[0], n)))
# Test 4d inner dimensions
data = numpy.random.rand(2, 3, 4, 5)
n = as_tensor_variable(data)
for i in [0, 1, 2, 3]:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论