提交 983ed2a2 作者: Arnaud Bergeron

Fix the last blocksparse problem.

上级 e399939d
...@@ -41,12 +41,12 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W, ...@@ -41,12 +41,12 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
/* Prepare lists for the batch */ /* Prepare lists for the batch */
size_t maxi = PyGpuArray_DIMS(h)[1]; size_t maxi = PyGpuArray_DIMS(h)[1];
size_t maxj = PyGpuArray_DIMS(o)[1]; size_t maxj = PyGpuArray_DIMS(out)[1];
size_t maxb = PyGpuArray_DIMS(o)[0]; size_t maxb = PyGpuArray_DIMS(out)[0];
ssize_t h_str_0 = PyGpuArray_STRIDES(h)[0]; ssize_t h_str_0 = PyGpuArray_STRIDES(h)[0];
ssize_t h_str_1 = PyGpuArray_STRIDES(h)[1]; ssize_t h_str_1 = PyGpuArray_STRIDES(h)[1];
ssize_t o_str_0 = PyGpuArray_STRIDES(o)[0]; ssize_t o_str_0 = PyGpuArray_STRIDES(out)[0];
ssize_t o_str_1 = PyGpuArray_STRIDES(o)[1]; ssize_t o_str_1 = PyGpuArray_STRIDES(out)[1];
ssize_t W_str_0 = PyGpuArray_STRIDES(W)[0]; ssize_t W_str_0 = PyGpuArray_STRIDES(W)[0];
ssize_t W_str_1 = PyGpuArray_STRIDES(W)[1]; ssize_t W_str_1 = PyGpuArray_STRIDES(W)[1];
...@@ -75,8 +75,8 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W, ...@@ -75,8 +75,8 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
size_t p = i + j * maxi + b * maxi * maxj; size_t p = i + j * maxi + b * maxi * maxj;
inp_list[p] = h->ga.data; inp_list[p] = h->ga.data;
offInp[p] = b * h_str_0 + i * h_str_1 + h->ga.offset; offInp[p] = b * h_str_0 + i * h_str_1 + h->ga.offset;
out_list[p] = o->ga.data; out_list[p] = out->ga.data;
offOut[p] = b * o_str_0 + j * o_str_1 + o->ga.offset; offOut[p] = b * o_str_0 + j * o_str_1 + out->ga.offset;
W_list[p] = W->ga.data; W_list[p] = W->ga.data;
offW[p] = *(DTYPE_INPUT_3 *)PyArray_GETPTR2(inputIdx, b, i) * W_str_0 + offW[p] = *(DTYPE_INPUT_3 *)PyArray_GETPTR2(inputIdx, b, i) * W_str_0 +
*(DTYPE_INPUT_4 *)PyArray_GETPTR2(outputIdx, b, j) * W_str_1 + *(DTYPE_INPUT_4 *)PyArray_GETPTR2(outputIdx, b, j) * W_str_1 +
...@@ -92,22 +92,22 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W, ...@@ -92,22 +92,22 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
lda = PyGpuArray_STRIDES(W)[3] / gpuarray_get_elsize(W->ga.typecode); lda = PyGpuArray_STRIDES(W)[3] / gpuarray_get_elsize(W->ga.typecode);
} }
if (o->ga.typecode == GA_FLOAT) { if (out->ga.typecode == GA_FLOAT) {
err = blas_ops->sgemvBatch(cb_fortran, transA, err = blas_ops->sgemvBatch(cb_fortran, transA,
PyGpuArray_DIMS(o)[2], PyGpuArray_DIMS(out)[2],
PyGpuArray_DIMS(h)[2], 1, PyGpuArray_DIMS(h)[2], 1,
W_list, offW, lda, W_list, offW, lda,
inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode), inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode),
1, out_list, offOut, PyGpuArray_STRIDES(o)[2] / gpuarray_get_elsize(o->ga.typecode), 1, out_list, offOut, PyGpuArray_STRIDES(out)[2] / gpuarray_get_elsize(out->ga.typecode),
PyGpuArray_DIMS(o)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(o)[0], 0); PyGpuArray_DIMS(out)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(out)[0], 0);
} else if (o->ga.typecode == GA_DOUBLE) { } else if (out->ga.typecode == GA_DOUBLE) {
err = blas_ops->dgemvBatch(cb_fortran, transA, err = blas_ops->dgemvBatch(cb_fortran, transA,
PyGpuArray_DIMS(o)[2], PyGpuArray_DIMS(out)[2],
PyGpuArray_DIMS(h)[2], 1, PyGpuArray_DIMS(h)[2], 1,
W_list, offW, lda, W_list, offW, lda,
inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode), inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode),
1, out_list, offOut, PyGpuArray_STRIDES(o)[2] / gpuarray_get_elsize(o->ga.typecode), 1, out_list, offOut, PyGpuArray_STRIDES(out)[2] / gpuarray_get_elsize(out->ga.typecode),
PyGpuArray_DIMS(o)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(o)[0], 0); PyGpuArray_DIMS(out)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(out)[0], 0);
} else { } else {
err = GA_DEVSUP_ERROR; err = GA_DEVSUP_ERROR;
} }
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论