提交 8f9c2a12 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use gemm_batched in the python code.

上级 29db8ffb
......@@ -14,6 +14,7 @@ import theano.misc.pycuda_init
from theano.misc.pycuda_init import pycuda_available
if pycuda_available:
import pycuda.gpuarray
from theano.misc.pycuda_utils import to_cudandarray
try:
import scikits.cuda
......@@ -25,16 +26,16 @@ except ImportError:
scikits_cuda_available = False
def gemm_batched(Al, Bl, Cl, m, n, k, lda, ldb, ldc,
alpha=numpy.float32(1.0), beta=numpy.float32(1.0)):
def gemm_batched(tA, tB, m, n, k, Al, lda, Bl, ldb, Cl, ldc,
alpha=numpy.float32(1.0), beta=numpy.float32(0.0)):
assert Al.shape[0] == Bl.shape[0]
assert Al.shape[0] == Cl.shape[0]
handle = scikits.cuda.misc._global_cublas_handle
cublas.cublasSgemmBatched(handle, 'n', 'n', m, n, k, alpha,
Bl.gpudata, ldb, Al.gpudata, lda,
beta, Cl.gpuadata, ldc,
cublas.cublasSgemmBatched(handle, tA, tB, m, n, k, alpha,
Al.ptr, lda, Bl.ptr, ldb,
beta, Cl.ptr, ldc,
Cl.shape[0])
......@@ -56,7 +57,6 @@ def gemv(alpha, A, x, beta, y):
x.gpudata, x.strides[0],
beta, y.gpudata, y.strides[0])
def ger(alpha, x, y, A):
assert A.shape[1] == x.shape[0]
assert A.shape[0] == y.shape[0]
......@@ -69,14 +69,6 @@ def ger(alpha, x, y, A):
A.gpudata, A.strides[0])
def bptr(a):
assert (a.ndim == 3 and a.strides[2] == 1)
return pycuda.gpuarray.arange(a.ptr,
a.ptr + a.shape[0] * a.strides[0] * 4,
a.strides[0] * 4,
dtype=cublas.ctypes.c_void_p)
class SparseBlockGemvSS(GpuOp):
def __init__(self, inplace):
self.inplace = inplace
......@@ -115,12 +107,41 @@ class SparseBlockGemvSS(GpuOp):
if not self.inplace:
o = o.copy()
dd = (o.shape[0] * h.shape[0],)
weightHostB = numpy.empty(dd, dtype='intp')
outputHostB = numpy.empty(dd, dtype='intp')
inputHostB = numpy.empty(dd, dtype='intp')
outputBatched = pycuda.gpuarray.GPUArray((h.shape[0], o.shape[0], o.shape[1]), dtype='float32')
k = 0
for j in range(o.shape[0]):
out_id = outputIdx[j]
for i in range(h.shape[0]):
inp_id = inputIdx[i]
gemv(numpy.float32(1.0), W[inp_id, out_id],
h[i], numpy.float32(1.0), o[j])
weightHostB[k] = W[inp_id, out_id].gpudata
outputHostB[k] = outputBatched[i, j].ptr
inputHostB[k] = h[i].gpudata
k += 1
weightB = pycuda.gpuarray.to_gpu(weightHostB)
inputB = pycuda.gpuarray.to_gpu(inputHostB)
outputB = pycuda.gpuarray.to_gpu(outputHostB)
tA = 'n'
lda = W.strides[2]
if lda == 1:
tA = 't'
lda = W.strides[3]
gemm_batched(tA, 'n', o.shape[1], 1, h.shape[1],
weightB, lda, inputB, h.strides[0],
outputB, o.strides[0],
beta=numpy.asarray(0.0, dtype='float32'))
outputBatchedG = to_cudandarray(outputBatched)
o += outputBatchedG.reduce_sum([1, 0, 0])
out[0] = o
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论