提交 b474b0b2 作者: Olivier Delalleau

Merge pull request #870 from pascanur/maxArgmax_bug

Max argmax bug
......@@ -2300,6 +2300,9 @@ class MaxAndArgmax(Op):
x, axis = inp
g_max, g_max_idx = grads
# Check to see if the gradient on max is None
if g_max is None:
return None, None
xmax = max(x, axis)
# Raise the g_max and xmax to the same number of dim as the input.
......
......@@ -2020,6 +2020,27 @@ class T_max_and_argmax(unittest.TestCase):
v = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v) == numpy.max(data, np_axis).shape
def test_arg_grad(self):
    """
    Check that computing the gradient of argmax(x).sum() fails because
    there is no differentiable path from the cost to the input, and not
    because of an error inside the grad method of the op.

    The test passes only if `grad` raises ValueError; any other exception
    propagates and is reported as an error by the test runner.
    """
    x = matrix()
    cost = argmax(x, axis=0).sum()
    try:
        # Only the call that is expected to raise lives inside the try;
        # its result is irrelevant, so it is not bound to a name.
        grad(cost, x)
    except ValueError:
        # This is the error saying there is no differentiable path to
        # the input, which is exactly what the test expects.
        return
    # Reaching this point means no ValueError was raised. Report it as a
    # proper test failure (AssertionError) rather than raising ValueError,
    # which the runner would count as an error in the test itself.
    self.fail('Test failed because exception saying '
              'no differentiable path found was not '
              'raised')
def test_grad(self):
data = rand(2, 3)
n = as_tensor_variable(data)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论