提交 e68999e8 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDotOp: simplify tests

上级 f6c48641
......@@ -44,7 +44,7 @@ def my_rand(*shape):
return theano._asarray(numpy.random.rand(*shape), dtype='float32')
class TestBatchedDot(TestCase):
class TestBatchedDot(unittest_tools.InferShapeTester):
def test_batched_dot_correctness(self):
......@@ -114,8 +114,6 @@ class TestBatchedDot(TestCase):
numpy.random.randn(5,2,6).astype(numpy.float32)],
mode=mode_with_gpu)
class TestBatchedDotInferShape(unittest_tools.InferShapeTester):
def test_infer_shape(self):
# only matrix/matrix is supported
admat = tensor.ftensor3()
......@@ -125,7 +123,7 @@ class TestBatchedDotInferShape(unittest_tools.InferShapeTester):
self._compile_and_check([admat, bdmat],
[BatchedDotOp()(admat, bdmat)],
[admat_val, bdmat_val],
(BatchedDotOp))
BatchedDotOp)
def test_dot22():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论