提交 7de5da54 作者: Arnaud Bergeron

Make all blas ops attempt to reuse the previous output buffer.

上级 2d77ade5
...@@ -27,6 +27,26 @@ class BlasOp(HideC): ...@@ -27,6 +27,26 @@ class BlasOp(HideC):
def c_init_code(self): def c_init_code(self):
return ['import_pygpu__blas();'] return ['import_pygpu__blas();']
def c_support_code(self):
    """Return the C source of the ``gpublas_try_copy`` helper.

    The generated helper takes the op's previous output buffer ``out``
    and the array ``y`` whose contents the BLAS call will start from:

    * If ``out`` is non-NULL, passes the ``GA_CARRAY`` flag check and
      passes ``theano_size_check`` against ``y``'s ndim, dims and
      typecode, the contents of ``y`` are moved into ``out`` with
      ``pygpu_move`` and ``out`` is returned.  If the move reports
      failure, ``out`` is released and NULL is returned.
    * Otherwise ``out`` is released (it may be NULL) and a new array
      from ``pygpu_copy(y, GA_ANY_ORDER)`` is returned.

    Callers assign the return value back to their output variable and
    treat NULL as failure.  They must not release ``out`` themselves
    before the call, because the helper inspects it.

    The string is returned here without ``%``-formatting, so it must
    contain plain C identifiers only; a ``%(name)s`` placeholder would
    be emitted literally into the C source.
    """
    return """
PyGpuArray *gpublas_try_copy(PyGpuArray *out, PyGpuArray *y) {
if (out &&
GpuArray_CHKFLAGS(&out->ga, GA_CARRAY) &&
theano_size_check(out, PyGpuArray_NDIM(y),
PyGpuArray_DIMS(y),
y->ga.typecode)) {
if (pygpu_move(out, y)) {
Py_XDECREF(out);
return NULL;
}
} else {
Py_XDECREF(out);
out = pygpu_copy(y, GA_ANY_ORDER);
}
return out;
}
"""
class GpuGemv(BlasOp, Gemv): class GpuGemv(BlasOp, Gemv):
def make_node(self, y, alpha, A, x, beta): def make_node(self, y, alpha, A, x, beta):
...@@ -50,21 +70,20 @@ class GpuGemv(BlasOp, Gemv): ...@@ -50,21 +70,20 @@ class GpuGemv(BlasOp, Gemv):
beta=inp[4], fail=sub['fail'], name=name) beta=inp[4], fail=sub['fail'], name=name)
if self.inplace: if self.inplace:
code = """ code = """
Py_XDECREF(%(out)s);
if (%(y)s->ga.strides[0] <= 0) { if (%(y)s->ga.strides[0] <= 0) {
%(out)s = pygpu_copy(%(y)s, GA_ANY_ORDER); %(out)s = gpublas_try_copy(%(out)s, %(y)s);
if (%(out)s == NULL) { if (%(out)s == NULL) {
%(fail)s %(fail)s
} }
} else { } else {
Py_XDECREF(%(out)s);
%(out)s = %(y)s; %(out)s = %(y)s;
Py_INCREF(%(out)s); Py_INCREF(%(out)s);
} }
""" % vars """ % vars
else: else:
code = """ code = """
Py_XDECREF(%(out)s); %(out)s = gpublas_try_copy(%(out)s, %(y)s);
%(out)s = pygpu_copy(%(y)s, GA_ANY_ORDER);
if (%(out)s == NULL) { if (%(out)s == NULL) {
%(fail)s %(fail)s
} }
...@@ -85,7 +104,7 @@ class GpuGemv(BlasOp, Gemv): ...@@ -85,7 +104,7 @@ class GpuGemv(BlasOp, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
gpugemv_no_inplace = GpuGemv(inplace=False) gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True) gpugemv_inplace = GpuGemv(inplace=True)
...@@ -113,13 +132,13 @@ class GpuGemm(BlasOp, Gemm): ...@@ -113,13 +132,13 @@ class GpuGemm(BlasOp, Gemm):
beta=inp[4], fail=sub['fail'], name=name) beta=inp[4], fail=sub['fail'], name=name)
if self.inplace: if self.inplace:
code = """ code = """
Py_XDECREF(%(out)s);
if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) { if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) {
%(out)s = pygpu_copy(%(C)s, GA_ANY_ORDER); %(out)s = gpublas_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) { if (%(out)s == NULL) {
%(fail)s %(fail)s
} }
} else { } else {
Py_XDECREF(%(out)s);
%(out)s = %(C)s; %(out)s = %(C)s;
Py_INCREF(%(out)s); Py_INCREF(%(out)s);
} }
...@@ -127,7 +146,7 @@ class GpuGemm(BlasOp, Gemm): ...@@ -127,7 +146,7 @@ class GpuGemm(BlasOp, Gemm):
else: else:
code = """ code = """
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
%(out)s = pygpu_copy(%(C)s, GA_ANY_ORDER); %(out)s = gpublas_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) { if (%(out)s == NULL) {
%(fail)s %(fail)s
} }
...@@ -148,7 +167,7 @@ class GpuGemm(BlasOp, Gemm): ...@@ -148,7 +167,7 @@ class GpuGemm(BlasOp, Gemm):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
gpugemm_no_inplace = GpuGemm(inplace=False) gpugemm_no_inplace = GpuGemm(inplace=False)
...@@ -177,21 +196,20 @@ class GpuGer(BlasOp, Ger): ...@@ -177,21 +196,20 @@ class GpuGer(BlasOp, Ger):
fail=sub['fail'], name=name) fail=sub['fail'], name=name)
if self.destructive: if self.destructive:
code = """ code = """
Py_XDECREF(%(out)s);
if (!GpuArray_ISONESEGMENT(&%(A)s->ga)) { if (!GpuArray_ISONESEGMENT(&%(A)s->ga)) {
%(out)s = pygpu_copy(%(A)s, GA_ANY_ORDER); %(out)s = gpublas_try_copy(%(out)s, %(A)s);
if (%(out)s == NULL) { if (%(out)s == NULL) {
%(fail)s %(fail)s
} }
} else { } else {
Py_XDECREF(%(out)s);
%(out)s = %(A)s; %(out)s = %(A)s;
Py_INCREF(%(out)s); Py_INCREF(%(out)s);
} }
""" % vars """ % vars
else: else:
code = """ code = """
Py_XDECREF(%(out)s); %(out)s = gpublas_try_copy(%(out)s, %(A)s);
%(out)s = pygpu_copy(%(A)s, GA_ANY_ORDER);
if (%(out)s == NULL) { if (%(out)s == NULL) {
%(fail)s %(fail)s
} }
...@@ -209,7 +227,7 @@ class GpuGer(BlasOp, Ger): ...@@ -209,7 +227,7 @@ class GpuGer(BlasOp, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
gpuger_no_inplace = GpuGer(destructive=False) gpuger_no_inplace = GpuGer(destructive=False)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论