提交 5cb343e6 作者: Pascal Lamblin

Fix C code using gpu_dot for gemv

Pass the right variable (`out` instead of `y`) to pygpu_blas_rdot when not working in place.
上级 bf9ce78c
...@@ -119,11 +119,11 @@ class GpuGemv(BlasOp): ...@@ -119,11 +119,11 @@ class GpuGemv(BlasOp):
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) { if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
ssize_t a_stride0 = %(A)s->ga.strides[0]; ssize_t a_stride0 = %(A)s->ga.strides[0];
%(A)s->ga.strides[0] = %(A)s->ga.strides[1]; %(A)s->ga.strides[0] = %(A)s->ga.strides[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) { if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) {
%(fail)s %(fail)s
} }
%(A)s->ga.strides[0] = a_stride0; %(A)s->ga.strides[0] = a_stride0;
} else if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) { } else if (pygpu_blas_rdot(%(x)s, %(A)s, %(out)s, 0) == -1) {
%(fail)s %(fail)s
} }
%(out)s->ga.nd = 1; %(out)s->ga.nd = 1;
...@@ -145,7 +145,7 @@ class GpuGemv(BlasOp): ...@@ -145,7 +145,7 @@ class GpuGemv(BlasOp):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (7,)
gpugemv_no_inplace = GpuGemv(inplace=False) gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True) gpugemv_inplace = GpuGemv(inplace=True)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论