提交 a93f1cbf authored 作者: Brian Cheung's avatar Brian Cheung

Merge pull request #1 from nouiz/briancheung-batched_dot_grad

Fix test in FAST_COMPILE
...@@ -43,6 +43,7 @@ mode_without_gpu.check_py_code = False ...@@ -43,6 +43,7 @@ mode_without_gpu.check_py_code = False
def my_rand(*shape): def my_rand(*shape):
return theano._asarray(numpy.random.rand(*shape), dtype='float32') return theano._asarray(numpy.random.rand(*shape), dtype='float32')
class TestBatchedDot(TestCase): class TestBatchedDot(TestCase):
def test_batched_dot_correctness(self): def test_batched_dot_correctness(self):
...@@ -107,9 +108,12 @@ class TestBatchedDot(TestCase): ...@@ -107,9 +108,12 @@ class TestBatchedDot(TestCase):
self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2)) self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2))
def test_batched_dot_gradient(self): def test_batched_dot_gradient(self):
theano.tests.unittest_tools.verify_grad(batched_dot, theano.tests.unittest_tools.verify_grad(
batched_dot,
[numpy.random.randn(5,7,2).astype(numpy.float32), [numpy.random.randn(5,7,2).astype(numpy.float32),
numpy.random.randn(5,2,6).astype(numpy.float32)]) numpy.random.randn(5,2,6).astype(numpy.float32)],
mode=mode_with_gpu)
def test_dot22(): def test_dot22():
def cmp(a_shp, b_shp): def cmp(a_shp, b_shp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论