提交 59b58be8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix strides when using dot inside gemv

上级 ebaef5af
......@@ -120,16 +120,12 @@ class GpuGemv(BlasOp):
%(out)s->ga.nd = 0;
%(A)s->ga.nd = 1;
%(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1];
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
ssize_t a_stride0 = %(A)s->ga.strides[0];
%(A)s->ga.strides[0] = %(A)s->ga.strides[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) {
%(fail)s
}
%(A)s->ga.strides[0] = a_stride0;
} else if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) {
ssize_t a_stride0 = %(A)s->ga.strides[0];
%(A)s->ga.strides[0] = %(A)s->ga.strides[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) {
%(fail)s
}
%(A)s->ga.strides[0] = a_stride0;
%(out)s->ga.nd = 1;
%(A)s->ga.nd = 2;
%(A)s->ga.dimensions[0] = 1;
......@@ -149,7 +145,7 @@ class GpuGemv(BlasOp):
return code
def c_code_cache_version(self):
return (7,)
return (8,)
gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论