提交 713fd0fe authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDotOp: implement and test infer_shape

上级 4ad36ddc
......@@ -212,6 +212,10 @@ class BatchedDotOp(GpuOp):
def c_code_cache_version(self):
return (1,)
def infer_shape(self, node, shapes):
xshp, yshp = shapes
return [xshp[:-1] + yshp[2:]]
batched_dot = BatchedDotOp()
"""
Call cublasSgemmBatched. Take 2 3d tensor as input.
......
......@@ -23,7 +23,7 @@ import theano.compile.mode
from theano.tensor.tests.test_blas import BaseGemv, TestBlasStrides, TestGer
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace, gpu_gemv_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace, gpu_ger_no_inplace
from theano.sandbox.cuda.blas import batched_dot
from theano.sandbox.cuda.blas import batched_dot, BatchedDotOp
if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
......@@ -115,6 +115,19 @@ class TestBatchedDot(TestCase):
mode=mode_with_gpu)
class TestBatchedDotInferShape(unittest_tools.InferShapeTester):
def test_infer_shape(self):
# only matrix/matrix is supported
admat = tensor.ftensor3()
bdmat = tensor.ftensor3()
admat_val = my_rand(7, 4, 5)
bdmat_val = my_rand(7, 5, 3)
self._compile_and_check([admat, bdmat],
[BatchedDotOp()(admat, bdmat)],
[admat_val, bdmat_val],
(BatchedDotOp))
def test_dot22():
def cmp(a_shp, b_shp):
a0 = my_rand(*a_shp)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论