提交 f2ed8c2f 作者: nouiz

Merge pull request #265 from delallea/fix_argmax_grad

Fix argmax grad
...@@ -1964,8 +1964,11 @@ class MaxAndArgmax(Op): ...@@ -1964,8 +1964,11 @@ class MaxAndArgmax(Op):
# Raise the g_max and xmax to the same number of dim as the input. # Raise the g_max and xmax to the same number of dim as the input.
pattern = [] pattern = []
out_dim = 0 out_dim = 0
for i in range(inp[0].ndim): if python_all(axis.data == range(x.ndim)):
if i == axis.data: # We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i == axis.data:
pattern.append('x') pattern.append('x')
else: else:
pattern.append(out_dim) pattern.append(out_dim)
......
...@@ -1546,9 +1546,9 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -1546,9 +1546,9 @@ class T_max_and_argmax(unittest.TestCase):
def check_grad_max(data, max_grad_data, axis=None): def check_grad_max(data, max_grad_data, axis=None):
""" """
Why this is needed? verify_grad is not enought? Why this is needed? verify_grad is not enough?
""" """
#This work only for axis in [0,None] # This works only for axis in [0, None].
assert axis in [0, None] assert axis in [0, None]
z = numpy.zeros_like(data) z = numpy.zeros_like(data)
z = z.flatten() z = z.flatten()
...@@ -1563,29 +1563,22 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -1563,29 +1563,22 @@ class T_max_and_argmax(unittest.TestCase):
z = z.reshape(data.shape) z = z.reshape(data.shape)
assert numpy.all(max_grad_data == z) assert numpy.all(max_grad_data == z)
#test grad of max for axis in (-1, 0, 1, None):
#axis is the last one for j in xrange(2):
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v, axis=axis)[j],
utt.verify_grad(lambda v: max_and_argmax(v, axis=-1)[1], [data]) [data])
if axis != 1:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v.flatten(),
utt.verify_grad(lambda v: max_and_argmax(v, axis=[0])[1], [data]) axis=axis)[j],
check_grad_max(data, eval_outputs(grad( [data])
max_and_argmax(n, axis=0)[0].sum(), n)), axis=0) if axis in (0, None):
check_grad_max(data, eval_outputs(grad(
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[0], [data]) max_and_argmax(n, axis=axis)[0].sum(), n)), axis=axis)
utt.verify_grad(lambda v: max_and_argmax(v, axis=[1])[1], [data]) check_grad_max(data, eval_outputs(grad(
#check_grad_max(data,eval_outputs(grad( max_and_argmax(n.flatten())[0], n)))
# max_and_argmax(n,axis=1)[0],n)),axis=1)
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v.flatten())[1], [data])
check_grad_max(data, eval_outputs(grad(
max_and_argmax(n.flatten())[0], n)))
# Test 4d inner dimensions # Test 4d inner dimensions
data = numpy.random.rand(2, 3, 4, 5) data = numpy.random.rand(2, 3, 4, 5)
n = as_tensor_variable(data)
for i in [0, 1, 2, 3]: for i in [0, 1, 2, 3]:
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data]) utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) utt.verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论