提交 0cf34d65 authored 作者: briancheung's avatar briancheung

Added test for gradient

上级 f7101c41
...@@ -106,6 +106,11 @@ class TestBatchedDot(TestCase): ...@@ -106,6 +106,11 @@ class TestBatchedDot(TestCase):
# Shape mismatch # Shape mismatch
self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2)) self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2))
def test_batched_dot_gradient(self):
theano.tests.unittest_tools.verify_grad(batched_dot,
[numpy.random.randn(5,7,2).astype(numpy.float32),
numpy.random.randn(5,2,6).astype(numpy.float32)])
def test_dot22(): def test_dot22():
def cmp(a_shp, b_shp): def cmp(a_shp, b_shp):
a0 = my_rand(*a_shp) a0 = my_rand(*a_shp)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论