提交 d1f762aa authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for dimensions of size 1 in all cases.

上级 8e23c533
...@@ -216,13 +216,16 @@ CudaNdarray_HOST_DIMS(%(o)s)[2], ...@@ -216,13 +216,16 @@ CudaNdarray_HOST_DIMS(%(o)s)[2],
transA = CUBLAS_OP_T; transA = CUBLAS_OP_T;
lda = CudaNdarray_HOST_STRIDES(%(W)s)[3]; lda = CudaNdarray_HOST_STRIDES(%(W)s)[3];
} }
if (lda == 0) lda = 1;
err = cublasSgemmBatched(handle, transA, CUBLAS_OP_N, err = cublasSgemmBatched(handle, transA, CUBLAS_OP_N,
CudaNdarray_HOST_DIMS(%(o)s)[2], 1, CudaNdarray_HOST_DIMS(%(o)s)[2], 1,
CudaNdarray_HOST_DIMS(%(h)s)[2], &alpha, CudaNdarray_HOST_DIMS(%(h)s)[2], &alpha,
%(name)s_W_list, lda, %(name)s_inp_list, %(name)s_W_list, lda, %(name)s_inp_list,
CudaNdarray_HOST_STRIDES(%(h)s)[1], CudaNdarray_HOST_STRIDES(%(h)s)[1] == 0 ?
1 : CudaNdarray_HOST_STRIDES(%(h)s)[1],
&beta, %(name)s_out_list, &beta, %(name)s_out_list,
CudaNdarray_HOST_STRIDES(%(o)s)[1], CudaNdarray_HOST_STRIDES(%(o)s)[1] == 0 ?
1 : CudaNdarray_HOST_STRIDES(%(o)s)[1],
CudaNdarray_HOST_DIMS(%(o)s)[1] * CudaNdarray_HOST_DIMS(%(o)s)[1] *
CudaNdarray_HOST_DIMS(%(h)s)[1] * CudaNdarray_HOST_DIMS(%(h)s)[1] *
CudaNdarray_HOST_DIMS(%(o)s)[0]); CudaNdarray_HOST_DIMS(%(o)s)[0]);
...@@ -256,7 +259,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[2]); ...@@ -256,7 +259,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[2]);
W=W, fail=sub['fail'], name=nodename) W=W, fail=sub['fail'], name=nodename)
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (6,)
def grad(self, inputs, grads): def grad(self, inputs, grads):
o, W, h, inputIdx, outputIdx = inputs o, W, h, inputIdx, outputIdx = inputs
...@@ -442,12 +445,17 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -442,12 +445,17 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
} }
{ {
cublasStatus_t err; cublasStatus_t err;
int str_y = CudaNdarray_HOST_STRIDES(%(y)s)[1];
if (str_y == 0) str_y = 1;
int str_x = CudaNdarray_HOST_STRIDES(%(x)s)[1];
if (str_x == 0) str_x = 1;
int str_out = CudaNdarray_HOST_STRIDES(%(out)s)[2];
if (str_out == 0) str_out = 1;
err = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_T, err = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_T,
CudaNdarray_HOST_DIMS(%(y)s)[2], CudaNdarray_HOST_DIMS(%(x)s)[2], 1, CudaNdarray_HOST_DIMS(%(y)s)[2], CudaNdarray_HOST_DIMS(%(x)s)[2], 1,
(float *)PyArray_GETPTR1(%(alpha)s, 0), %(name)s_y_list, (float *)PyArray_GETPTR1(%(alpha)s, 0), %(name)s_y_list, str_y,
CudaNdarray_HOST_STRIDES(%(y)s)[1], %(name)s_x_list, %(name)s_x_list, str_x, (float *)PyArray_GETPTR1(%(beta)s, 0),
CudaNdarray_HOST_STRIDES(%(x)s)[1], (float *)PyArray_GETPTR1(%(beta)s, 0), %(name)s_out_list, str_out,
%(name)s_out_list, CudaNdarray_HOST_STRIDES(%(out)s)[2],
CudaNdarray_HOST_DIMS(%(x)s)[0] * CudaNdarray_HOST_DIMS(%(x)s)[0] *
CudaNdarray_HOST_DIMS(%(x)s)[1] * CudaNdarray_HOST_DIMS(%(x)s)[1] *
CudaNdarray_HOST_DIMS(%(y)s)[1]); CudaNdarray_HOST_DIMS(%(y)s)[1]);
...@@ -459,7 +467,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -459,7 +467,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
alpha=alpha, beta=beta, fail=sub['fail']) alpha=alpha, beta=beta, fail=sub['fail'])
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
sparse_block_outer_ss = SparseBlockOuterSS(False) sparse_block_outer_ss = SparseBlockOuterSS(False)
......
...@@ -122,6 +122,22 @@ def test_blocksparse_grad(): ...@@ -122,6 +122,22 @@ def test_blocksparse_grad():
utt.verify_grad(f, [b_val, h_val, W_val]) utt.verify_grad(f, [b_val, h_val, W_val])
def test_blocksparse_grad_1():
    """Check the gradient when every dimension has size 1.

    NOTE(review): this appears to exercise the size-1 (stride-0) code path
    of the batched-gemm implementation this commit fixes -- confirm against
    the corresponding C code.
    """
    # All-size-1 inputs: h is 3-d, W is 4-d, b is 2-d.
    h_val = randn(1, 1, 1).astype('float32')
    W_val = randn(1, 1, 1, 1).astype('float32')
    b_val = randn(1, 1).astype('float32')

    # permutation(1) is deterministically [0]; shape the index arrays (1, 1).
    idx_in = theano.tensor.constant(numpy.random.permutation(1)[:1][None, :])
    idx_out = theano.tensor.constant(numpy.random.permutation(1)[:1][None, :])

    def op(bias, hid, weights):
        # Same expression as the op under test, with the bias rows gathered
        # by the output indices.
        return sparse_block_gemv_ss(bias.take(idx_out, axis=0), weights,
                                    hid, idx_in, idx_out)

    utt.verify_grad(op, [b_val, h_val, W_val])
def test_blocksparse_grad_shape(): def test_blocksparse_grad_shape():
b = tensor.fmatrix() b = tensor.fmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论