提交 7de5da54 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make all blas ops attempt to reuse the previous output buffer.

上级 2d77ade5
......@@ -27,6 +27,26 @@ class BlasOp(HideC):
def c_init_code(self):
return ['import_pygpu__blas();']
def c_support_code(self):
return """
PyGpuArray *gpublas_try_copy(PyGpuArray *out, PyGpuArray *y)
if (out &&
GpuArray_CHKFLAGS(&out->ga, GA_CARRAY) &&
theano_size_check(out, PyGpuArray_NDIM(y),
PyGpuArray_DIMS(y),
y->ga.typecode)) {
if (pygpu_move(out, y)) {
Py_XDECREF(%(out)s)
return NULL;
}
} else {
Py_XDECREF(out);
out = pygpu_copy(y, GA_ANY_ORDER);
}
return out;
}
"""
class GpuGemv(BlasOp, Gemv):
def make_node(self, y, alpha, A, x, beta):
......@@ -50,21 +70,20 @@ class GpuGemv(BlasOp, Gemv):
beta=inp[4], fail=sub['fail'], name=name)
if self.inplace:
code = """
Py_XDECREF(%(out)s);
if (%(y)s->ga.strides[0] <= 0) {
%(out)s = pygpu_copy(%(y)s, GA_ANY_ORDER);
%(out)s = gpublas_try_copy(%(out)s, %(y)s);
if (%(out)s == NULL) {
%(fail)s
}
} else {
Py_XDECREF(%(out)s);
%(out)s = %(y)s;
Py_INCREF(%(out)s);
}
""" % vars
else:
code = """
Py_XDECREF(%(out)s);
%(out)s = pygpu_copy(%(y)s, GA_ANY_ORDER);
%(out)s = gpublas_try_copy(%(out)s, %(y)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -85,7 +104,7 @@ class GpuGemv(BlasOp, Gemv):
return code
def c_code_cache_version(self):
return (2,)
return (3,)
gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True)
......@@ -113,13 +132,13 @@ class GpuGemm(BlasOp, Gemm):
beta=inp[4], fail=sub['fail'], name=name)
if self.inplace:
code = """
Py_XDECREF(%(out)s);
if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) {
%(out)s = pygpu_copy(%(C)s, GA_ANY_ORDER);
%(out)s = gpublas_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) {
%(fail)s
}
} else {
Py_XDECREF(%(out)s);
%(out)s = %(C)s;
Py_INCREF(%(out)s);
}
......@@ -127,7 +146,7 @@ class GpuGemm(BlasOp, Gemm):
else:
code = """
Py_XDECREF(%(out)s);
%(out)s = pygpu_copy(%(C)s, GA_ANY_ORDER);
%(out)s = gpublas_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -148,7 +167,7 @@ class GpuGemm(BlasOp, Gemm):
return code
def c_code_cache_version(self):
return (2,)
return (3,)
gpugemm_no_inplace = GpuGemm(inplace=False)
......@@ -177,21 +196,20 @@ class GpuGer(BlasOp, Ger):
fail=sub['fail'], name=name)
if self.destructive:
code = """
Py_XDECREF(%(out)s);
if (!GpuArray_ISONESEGMENT(&%(A)s->ga)) {
%(out)s = pygpu_copy(%(A)s, GA_ANY_ORDER);
%(out)s = gpublas_try_copy(%(out)s, %(A)s);
if (%(out)s == NULL) {
%(fail)s
}
} else {
Py_XDECREF(%(out)s);
%(out)s = %(A)s;
Py_INCREF(%(out)s);
}
""" % vars
else:
code = """
Py_XDECREF(%(out)s);
%(out)s = pygpu_copy(%(A)s, GA_ANY_ORDER);
%(out)s = gpublas_try_copy(%(out)s, %(A)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -209,7 +227,7 @@ class GpuGer(BlasOp, Ger):
return code
def c_code_cache_version(self):
return (1,)
return (2,)
gpuger_no_inplace = GpuGer(destructive=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论