提交 d6b3dff4 作者: Frédéric Bastien

Merge pull request #3947 from Sentient07/issue-#3944

Renamed BatchedDotOp to GpuBatchedDot
.. _libdoc_cuda_op: .. _libdoc_cuda_op:
====================================================== ======================================================
...@@ -19,7 +20,7 @@ Blas Op ...@@ -19,7 +20,7 @@ Blas Op
.. automodule:: theano.sandbox.cuda.blas .. automodule:: theano.sandbox.cuda.blas
:members: :members:
.. autoclass:: theano.sandbox.cuda.blas.BatchedDotOp .. autoclass:: theano.sandbox.cuda.blas.GpuBatchedDot
Nnet Op Nnet Op
======= =======
......
...@@ -15,7 +15,7 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, ...@@ -15,7 +15,7 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
from theano.tensor import as_tensor_variable from theano.tensor import as_tensor_variable
class BatchedDotOp(GpuOp): class GpuBatchedDot(GpuOp):
__props__ = () __props__ = ()
def make_node(self, inp1, inp2): def make_node(self, inp1, inp2):
...@@ -216,7 +216,8 @@ class BatchedDotOp(GpuOp): ...@@ -216,7 +216,8 @@ class BatchedDotOp(GpuOp):
xshp, yshp = shapes xshp, yshp = shapes
return [xshp[:-1] + yshp[2:]] return [xshp[:-1] + yshp[2:]]
batched_dot = BatchedDotOp() batched_dot = GpuBatchedDot()
BatchedDotOp = GpuBatchedDot()
""" """
Call cublasSgemmBatched. Take 2 3d tensor as input. Call cublasSgemmBatched. Take 2 3d tensor as input.
""" """
......
...@@ -33,7 +33,7 @@ from theano.sandbox.cuda.basic_ops import ( ...@@ -33,7 +33,7 @@ from theano.sandbox.cuda.basic_ops import (
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.blas import ( from theano.sandbox.cuda.blas import (
gpu_dot22, gpu_dot22scalar, gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv, gpu_dot22, gpu_dot22scalar, gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv,
BatchedDotOp, GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights, GpuBatchedDot, GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights,
GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights) GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights)
from theano.sandbox.cuda.blas import gpu_gemv_inplace from theano.sandbox.cuda.blas import gpu_gemv_inplace
...@@ -629,7 +629,7 @@ def local_gpu_batched_dot(node): ...@@ -629,7 +629,7 @@ def local_gpu_batched_dot(node):
x_ = x_.dimshuffle(0, "x", 1) x_ = x_.dimshuffle(0, "x", 1)
if y.ndim == 2: if y.ndim == 2:
y_ = y_.dimshuffle(0, 1, "x") y_ = y_.dimshuffle(0, 1, "x")
z = BatchedDotOp()(as_cuda_ndarray_variable(x_), z = GpuBatchedDot()(as_cuda_ndarray_variable(x_),
as_cuda_ndarray_variable(y_)) as_cuda_ndarray_variable(y_))
# unpad z shape # unpad z shape
if x.ndim == 2: if x.ndim == 2:
......
...@@ -23,7 +23,7 @@ import theano.compile.mode ...@@ -23,7 +23,7 @@ import theano.compile.mode
from theano.tensor.tests.test_blas import BaseGemv, TestBlasStrides, TestGer from theano.tensor.tests.test_blas import BaseGemv, TestBlasStrides, TestGer
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace, gpu_gemv_inplace from theano.sandbox.cuda.blas import gpu_gemv_no_inplace, gpu_gemv_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace, gpu_ger_no_inplace from theano.sandbox.cuda.blas import gpu_ger_inplace, gpu_ger_no_inplace
from theano.sandbox.cuda.blas import batched_dot, BatchedDotOp from theano.sandbox.cuda.blas import batched_dot, GpuBatchedDot
if theano.config.mode == 'FAST_COMPILE': if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu') mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
...@@ -121,9 +121,9 @@ class TestBatchedDot(unittest_tools.InferShapeTester): ...@@ -121,9 +121,9 @@ class TestBatchedDot(unittest_tools.InferShapeTester):
admat_val = my_rand(7, 4, 5) admat_val = my_rand(7, 4, 5)
bdmat_val = my_rand(7, 5, 3) bdmat_val = my_rand(7, 5, 3)
self._compile_and_check([admat, bdmat], self._compile_and_check([admat, bdmat],
[BatchedDotOp()(admat, bdmat)], [GpuBatchedDot()(admat, bdmat)],
[admat_val, bdmat_val], [admat_val, bdmat_val],
BatchedDotOp) GpuBatchedDot)
def test_dot22(): def test_dot22():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论