提交 dc464643 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5303 from khaotik/pygpu_dot_speedup

GPU gemv->dot speedup for new backend
...@@ -52,6 +52,7 @@ class GpuGemv(BlasOp): ...@@ -52,6 +52,7 @@ class GpuGemv(BlasOp):
y = as_gpuarray_variable(y, ctx_name) y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha).astype('float64') alpha = as_tensor_variable(alpha).astype('float64')
beta = as_tensor_variable(beta).astype('float64') beta = as_tensor_variable(beta).astype('float64')
assert alpha.ndim == 0 assert alpha.ndim == 0
assert beta.ndim == 0 assert beta.ndim == 0
assert A.ndim == 2 assert A.ndim == 2
...@@ -91,6 +92,8 @@ class GpuGemv(BlasOp): ...@@ -91,6 +92,8 @@ class GpuGemv(BlasOp):
%(fail)s %(fail)s
} }
""" % vars """ % vars
# in case of possible speed up using blas dot,
# temporary hack A to 1D for vector-vector dot
code += """ code += """
if (PyGpuArray_DIM(%(A)s, 1) == 0) { if (PyGpuArray_DIM(%(A)s, 1) == 0) {
int code; int code;
...@@ -99,11 +102,32 @@ class GpuGemv(BlasOp): ...@@ -99,11 +102,32 @@ class GpuGemv(BlasOp):
PyErr_SetString(PyExc_RuntimeError, "Memset failed"); PyErr_SetString(PyExc_RuntimeError, "Memset failed");
%(fail)s %(fail)s
} }
} else if (pygpu_blas_rgemv(cb_no_trans, } else if ( PyGpuArray_DIM(%(A)s, 0) == 1
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0], &&((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0] == (dtype_%(alpha)s)1.
%(A)s, %(x)s, &&((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0] == (dtype_%(beta)s)0.
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0], ) {
%(out)s, 0) == -1) { %(out)s->ga.nd = 0;
%(A)s->ga.nd = 1;
%(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1];
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
ssize_t a_stride0 = %(A)s->ga.strides[0];
%(A)s->ga.strides[0] = %(A)s->ga.strides[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) {
%(fail)s
}
%(A)s->ga.strides[0] = a_stride0;
} else if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) {
%(fail)s
}
%(out)s->ga.nd = 1;
%(A)s->ga.nd = 2;
%(A)s->ga.dimensions[0] = 1;
} else if (
pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
%(out)s, 0) == -1) {
%(fail)s %(fail)s
} }
""" % vars """ % vars
...@@ -114,7 +138,7 @@ class GpuGemv(BlasOp): ...@@ -114,7 +138,7 @@ class GpuGemv(BlasOp):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (6,)
gpugemv_no_inplace = GpuGemv(inplace=False) gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True) gpugemv_inplace = GpuGemv(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论