提交 dc464643 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5303 from khaotik/pygpu_dot_speedup

GPU gemv->dot speedup for new backend
......@@ -52,6 +52,7 @@ class GpuGemv(BlasOp):
y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha).astype('float64')
beta = as_tensor_variable(beta).astype('float64')
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 2
......@@ -91,6 +92,8 @@ class GpuGemv(BlasOp):
%(fail)s
}
""" % vars
# in case of possible speed up using blas dot,
# temporary hack A to 1D for vector-vector dot
code += """
if (PyGpuArray_DIM(%(A)s, 1) == 0) {
int code;
......@@ -99,11 +102,32 @@ class GpuGemv(BlasOp):
PyErr_SetString(PyExc_RuntimeError, "Memset failed");
%(fail)s
}
} else if (pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
%(out)s, 0) == -1) {
} else if ( PyGpuArray_DIM(%(A)s, 0) == 1
&&((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0] == (dtype_%(alpha)s)1.
&&((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0] == (dtype_%(beta)s)0.
) {
%(out)s->ga.nd = 0;
%(A)s->ga.nd = 1;
%(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1];
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
ssize_t a_stride0 = %(A)s->ga.strides[0];
%(A)s->ga.strides[0] = %(A)s->ga.strides[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) {
%(fail)s
}
%(A)s->ga.strides[0] = a_stride0;
} else if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) {
%(fail)s
}
%(out)s->ga.nd = 1;
%(A)s->ga.nd = 2;
%(A)s->ga.dimensions[0] = 1;
} else if (
pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
%(out)s, 0) == -1) {
%(fail)s
}
""" % vars
......@@ -114,7 +138,7 @@ class GpuGemv(BlasOp):
return code
def c_code_cache_version(self):
return (5,)
return (6,)
gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论