提交 983ed2a2 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the last blocksparse problem.

上级 e399939d
......@@ -41,12 +41,12 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
/* Prepare lists for the batch */
size_t maxi = PyGpuArray_DIMS(h)[1];
size_t maxj = PyGpuArray_DIMS(o)[1];
size_t maxb = PyGpuArray_DIMS(o)[0];
size_t maxj = PyGpuArray_DIMS(out)[1];
size_t maxb = PyGpuArray_DIMS(out)[0];
ssize_t h_str_0 = PyGpuArray_STRIDES(h)[0];
ssize_t h_str_1 = PyGpuArray_STRIDES(h)[1];
ssize_t o_str_0 = PyGpuArray_STRIDES(o)[0];
ssize_t o_str_1 = PyGpuArray_STRIDES(o)[1];
ssize_t o_str_0 = PyGpuArray_STRIDES(out)[0];
ssize_t o_str_1 = PyGpuArray_STRIDES(out)[1];
ssize_t W_str_0 = PyGpuArray_STRIDES(W)[0];
ssize_t W_str_1 = PyGpuArray_STRIDES(W)[1];
......@@ -75,8 +75,8 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
size_t p = i + j * maxi + b * maxi * maxj;
inp_list[p] = h->ga.data;
offInp[p] = b * h_str_0 + i * h_str_1 + h->ga.offset;
out_list[p] = o->ga.data;
offOut[p] = b * o_str_0 + j * o_str_1 + o->ga.offset;
out_list[p] = out->ga.data;
offOut[p] = b * o_str_0 + j * o_str_1 + out->ga.offset;
W_list[p] = W->ga.data;
offW[p] = *(DTYPE_INPUT_3 *)PyArray_GETPTR2(inputIdx, b, i) * W_str_0 +
*(DTYPE_INPUT_4 *)PyArray_GETPTR2(outputIdx, b, j) * W_str_1 +
......@@ -92,22 +92,22 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
lda = PyGpuArray_STRIDES(W)[3] / gpuarray_get_elsize(W->ga.typecode);
}
if (o->ga.typecode == GA_FLOAT) {
if (out->ga.typecode == GA_FLOAT) {
err = blas_ops->sgemvBatch(cb_fortran, transA,
PyGpuArray_DIMS(o)[2],
PyGpuArray_DIMS(out)[2],
PyGpuArray_DIMS(h)[2], 1,
W_list, offW, lda,
inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode),
1, out_list, offOut, PyGpuArray_STRIDES(o)[2] / gpuarray_get_elsize(o->ga.typecode),
PyGpuArray_DIMS(o)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(o)[0], 0);
} else if (o->ga.typecode == GA_DOUBLE) {
1, out_list, offOut, PyGpuArray_STRIDES(out)[2] / gpuarray_get_elsize(out->ga.typecode),
PyGpuArray_DIMS(out)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(out)[0], 0);
} else if (out->ga.typecode == GA_DOUBLE) {
err = blas_ops->dgemvBatch(cb_fortran, transA,
PyGpuArray_DIMS(o)[2],
PyGpuArray_DIMS(out)[2],
PyGpuArray_DIMS(h)[2], 1,
W_list, offW, lda,
inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode),
1, out_list, offOut, PyGpuArray_STRIDES(o)[2] / gpuarray_get_elsize(o->ga.typecode),
PyGpuArray_DIMS(o)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(o)[0], 0);
1, out_list, offOut, PyGpuArray_STRIDES(out)[2] / gpuarray_get_elsize(out->ga.typecode),
PyGpuArray_DIMS(out)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(out)[0], 0);
} else {
err = GA_DEVSUP_ERROR;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论