提交 93c40296 作者: Arnaud Bergeron

Add fallback copy to the inplace C code for all gpuarray blas ops.

上级 82f53497
...@@ -45,8 +45,15 @@ class GpuGemv(BlasOp, Gemv): ...@@ -45,8 +45,15 @@ class GpuGemv(BlasOp, Gemv):
if self.inplace: if self.inplace:
code = """ code = """
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
%(out)s = %(y)s; if (%(y)s->ga.strides[0] <= 0) {
Py_INCREF(%(out)s); %(out)s = pygpu_copy(%(y)s, GA_ANY_ORDER);
if (%(out)s == NULL) {
%(fail)s
}
} else {
%(out)s = %(y)s;
Py_INCREF(%(out)s);
}
""" % vars """ % vars
else: else:
code = """ code = """
...@@ -72,7 +79,7 @@ class GpuGemv(BlasOp, Gemv): ...@@ -72,7 +79,7 @@ class GpuGemv(BlasOp, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
gpugemv_no_inplace = GpuGemv(inplace=False) gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True) gpugemv_inplace = GpuGemv(inplace=True)
...@@ -101,8 +108,15 @@ class GpuGemm(BlasOp, Gemm): ...@@ -101,8 +108,15 @@ class GpuGemm(BlasOp, Gemm):
if self.inplace: if self.inplace:
code = """ code = """
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
%(out)s = %(C)s; if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) {
Py_INCREF(%(out)s); %(out)s = pygpu_copy(%(C)s, GA_ANY_ORDER);
if (%(out)s == NULL) {
%(fail)s
}
} else {
%(out)s = %(C)s;
Py_INCREF(%(out)s);
}
""" % vars """ % vars
else: else:
code = """ code = """
...@@ -128,7 +142,7 @@ class GpuGemm(BlasOp, Gemm): ...@@ -128,7 +142,7 @@ class GpuGemm(BlasOp, Gemm):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
gpugemm_no_inplace = GpuGemm(inplace=False) gpugemm_no_inplace = GpuGemm(inplace=False)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论