@@ -2019,27 +2019,16 @@ class T_max_and_argmax(unittest.TestCase):
def test_arg_grad(self):
    """
    Check that the gradient of argmax(x).sum() w.r.t. x is 0.

    argmax is piecewise constant in x, so there is no differentiable
    path from the cost to the input; the grad method should return a
    zero gradient rather than raise a "disconnected input" error or
    return NaN.
    """
    x = matrix()
    cost = argmax(x, axis=0).sum()
    # Must not raise: a disconnected-input ValueError here would mean
    # the op's grad method reports an error instead of a zero gradient.
    gx = grad(cost, x)
    # The gradient graph should fold to a single constant zero.
    val = tensor.get_constant_value(gx)
    assert val == 0.0