提交 0bc12fe9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for the fortran order in gemv (and a test for it).

上级 d12f4aea
......@@ -42,10 +42,17 @@ def gemv(alpha, A, x, beta, y):
assert A.shape[0] == x.shape[0]
assert A.shape[1] == y.shape[0]
if A.strides[0] == 1:
n, m = 0, 1
trans = 't'
else:
n, m = 1, 0
trans = 'n'
handle = scikits.cuda.misc._global_cublas_handle
cublas.cublasSgemv(handle, 'n', A.shape[1], A.shape[0], alpha,
A.gpudata, A.strides[0],
cublas.cublasSgemv(handle, trans, A.shape[n], A.shape[m], alpha,
A.gpudata, A.strides[m],
x.gpudata, x.strides[0],
beta, y.gpudata, y.strides[0])
......
......@@ -13,6 +13,8 @@ import theano.sandbox.cuda as cuda_ndarray
if cuda_ndarray.cuda_available == False:
raise SkipTest('Optional package cuda disabled')
from theano.sandbox.cuda.basic_ops import (GpuDimShuffle,
as_cuda_ndarray_variable)
from theano.sandbox.cuda.blocksparse import (sparse_block_dot_SS,
sparse_block_gemv_ss,
sparse_block_gemv_ss_inplace,
......@@ -55,6 +57,7 @@ def blocksparse(W, h, iIdx, b, oIdx):
return o
def test_blocksparse():
b = tensor.fmatrix()
W = tensor.ftensor4()
......@@ -74,6 +77,29 @@ def test_blocksparse():
utt.assert_allclose(ref_out, th_out)
# Test the Fortran order for W (which can happen in the grad for some graphs).
def test_blocksparseF():
    """Run sparse_block_dot_SS on a W whose last two axes were dimshuffled.

    The compiled function receives ``W_val`` with axes 2 and 3 swapped and
    the graph swaps them back with a ``GpuDimShuffle`` using the pattern
    ``(0, 1, 3, 2)``, so the op is fed values equal to ``W_val``.  The
    result is compared against the ``blocksparse`` reference computed
    from the untouched ``W_val``.

    NOTE(review): this is meant to exercise the Fortran-ordered W case
    (reported to show up in the grad of some graphs); that the shuffled
    array really is non C-contiguous on the device is assumed, not
    checked here.
    """
    # Symbolic inputs.
    b = tensor.fmatrix()
    W = tensor.ftensor4()
    h = tensor.fmatrix()
    iIdx = tensor.lvector()
    oIdx = tensor.lvector()

    # Swap the two trailing axes of W on the GPU before handing it to the op.
    swap_last_two = GpuDimShuffle((False, False, False, False),
                                  (0, 1, 3, 2))
    W_shuffled = swap_last_two(as_cuda_ndarray_variable(W))
    o = sparse_block_dot_SS(W_shuffled, h, iIdx, b, oIdx)

    f = theano.function([W, h, iIdx, b, oIdx], o)

    W_val, h_val, iIdx_val, b_val, oIdx_val = blocksparse_data()

    # Pre-swap the numeric W so the in-graph shuffle restores W_val's values.
    W_swapped_val = numpy.swapaxes(W_val, 2, 3)
    th_out = f(W_swapped_val, h_val, iIdx_val, b_val, oIdx_val)
    ref_out = blocksparse(W_val, h_val, iIdx_val, b_val, oIdx_val)

    utt.assert_allclose(ref_out, th_out)
def test_blocksparse_grad():
h_val = randn(2, 3).astype('float32')
iIdx_val = numpy.random.permutation(3)[:2]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论