提交 472afaf3 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Test that an error is raised when calling tensor.grad on a non-scalar cost.

Finishes closing of #539.
上级 0b186aec
...@@ -2414,6 +2414,17 @@ class test_grad(unittest.TestCase): ...@@ -2414,6 +2414,17 @@ class test_grad(unittest.TestCase):
self.failUnless((f(a) == 0).all()) # Zero gradient. self.failUnless((f(a) == 0).all()) # Zero gradient.
self.failUnless(a.shape == f(a).shape) # With proper shape. self.failUnless(a.shape == f(a).shape) # With proper shape.
def test_cost_is_scalar(self):
'''grad: Test that a non-scalar cost raises a TypeError'''
s = scalar()
v = vector()
m = matrix()
# grad(v,...) and grad(m,...) should fail
self.assertRaises(TypeError, grad, v, s)
self.assertRaises(TypeError, grad, v, m)
self.assertRaises(TypeError, grad, m, s)
self.assertRaises(TypeError, grad, m, v)
class T_op_cache(unittest.TestCase): class T_op_cache(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论