提交 7646cf27 authored 作者: khaotik's avatar khaotik

new backend gemv speedhack using vector-vector dot

上级 cdac0c69
......@@ -9,6 +9,7 @@ from theano.compile import optdb
from theano.gof import LocalOptGroup
from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import in2out
from theano.tensor.var import TensorConstant
from .basic_ops import (GpuArrayType, CGpuKernelBase,
as_gpuarray_variable, gpu_contiguous, infer_context_name)
......@@ -52,6 +53,15 @@ class GpuGemv(BlasOp):
y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha).astype('float64')
beta = as_tensor_variable(beta).astype('float64')
# if alpha==1. and beta==0., we add runtime check
# for possible speed up using vector-vector dot as
# gemv tend to be slower for vector-vector dot
self._use_rt_dot_check = False
if all(map(lambda v: isinstance(v, TensorConstant), [alpha, beta])):
if alpha.value == 1. and beta.value == 0.:
self._use_rt_dot_check = True
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 2
......@@ -65,6 +75,7 @@ class GpuGemv(BlasOp):
inplace = self.inplace
if inplace and y.strides[0] < 0:
inplace = False
print(alpha, beta)
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y,
overwrite_y=inplace)
......@@ -99,7 +110,33 @@ class GpuGemv(BlasOp):
PyErr_SetString(PyExc_RuntimeError, "Memset failed");
%(fail)s
}
} else if (pygpu_blas_rgemv(cb_no_trans,
} """ % vars
if self._use_rt_dot_check:
# temporary hack A to 1D for vector-vector dot
code += """
else if (%(A)s->ga.dimensions[0]==1) {
%(out)s->ga.nd = 0;
%(A)s->ga.nd = 1;
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
%(A)s->ga.strides[1] ^= %(A)s->ga.strides[0];
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
}
%(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) {
%(fail)s
}
%(A)s->ga.dimensions[0] = 1;
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
%(A)s->ga.strides[1] ^= %(A)s->ga.strides[0];
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
}
%(A)s->ga.nd = 2;
%(out)s->ga.nd = 1;
} """ % vars
code += """
else if (pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论