提交 59b58be8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix strides when using dot inside gemv

上级 ebaef5af
...@@ -120,16 +120,12 @@ class GpuGemv(BlasOp): ...@@ -120,16 +120,12 @@ class GpuGemv(BlasOp):
%(out)s->ga.nd = 0; %(out)s->ga.nd = 0;
%(A)s->ga.nd = 1; %(A)s->ga.nd = 1;
%(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1]; %(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1];
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
ssize_t a_stride0 = %(A)s->ga.strides[0]; ssize_t a_stride0 = %(A)s->ga.strides[0];
%(A)s->ga.strides[0] = %(A)s->ga.strides[1]; %(A)s->ga.strides[0] = %(A)s->ga.strides[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) { if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) {
%(fail)s %(fail)s
} }
%(A)s->ga.strides[0] = a_stride0; %(A)s->ga.strides[0] = a_stride0;
} else if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) {
%(fail)s
}
%(out)s->ga.nd = 1; %(out)s->ga.nd = 1;
%(A)s->ga.nd = 2; %(A)s->ga.nd = 2;
%(A)s->ga.dimensions[0] = 1; %(A)s->ga.dimensions[0] = 1;
...@@ -149,7 +145,7 @@ class GpuGemv(BlasOp): ...@@ -149,7 +145,7 @@ class GpuGemv(BlasOp):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
gpugemv_no_inplace = GpuGemv(inplace=False) gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True) gpugemv_inplace = GpuGemv(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论