提交 4a961015 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

better test

上级 619b7be6
...@@ -2021,14 +2021,24 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -2021,14 +2021,24 @@ class T_max_and_argmax(unittest.TestCase):
assert tuple(v) == numpy.max(data, np_axis).shape assert tuple(v) == numpy.max(data, np_axis).shape
def test_arg_grad(self): def test_arg_grad(self):
"""
The test checks if computing the gradient of argmax(x).sum() fails
because there is no differentiable path from cost to the input and
not because of an error of the grad method of the op
"""
x = matrix() x = matrix()
cost = argmax(x, axis=0).sum() cost = argmax(x, axis=0).sum()
value_error_raised = False
try: try:
gx = grad(cost, x) gx = grad(cost, x)
except ValueError: except ValueError:
# It is the error saying there is no differentiable path to the # It is the error saying there is no differentiable path to the
# input # input
pass value_error_raised = True
if value_error_raised:
raise ValueError(('Test failed because exception saying '
'no differentiable path found was not '
'raised'))
def test_grad(self): def test_grad(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论