提交 57d65201 authored 作者: khaotik's avatar khaotik

always runtime check gemv hack

上级 8f64fab5
......@@ -9,7 +9,6 @@ from theano.compile import optdb
from theano.gof import LocalOptGroup
from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import in2out
from theano.tensor.var import TensorConstant
from .basic_ops import (GpuArrayType, CGpuKernelBase,
as_gpuarray_variable, gpu_contiguous, infer_context_name)
......@@ -54,14 +53,6 @@ class GpuGemv(BlasOp):
alpha = as_tensor_variable(alpha).astype('float64')
beta = as_tensor_variable(beta).astype('float64')
# if alpha == 1. and beta == 0., we add a runtime check
# for a possible speed-up using a vector-vector dot, as
# gemv tends to be slower than dot for vector-vector products
self._use_rt_dot_check = False
if all(map(lambda v: isinstance(v, TensorConstant), [alpha, beta])):
if alpha.value == 1. and beta.value == 0.:
self._use_rt_dot_check = True
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 2
......@@ -101,6 +92,8 @@ class GpuGemv(BlasOp):
%(fail)s
}
""" % vars
# for a possible speed-up using blas dot,
# temporarily reshape A to 1D for a vector-vector dot
code += """
if (PyGpuArray_DIM(%(A)s, 1) == 0) {
int code;
......@@ -109,37 +102,32 @@ class GpuGemv(BlasOp):
PyErr_SetString(PyExc_RuntimeError, "Memset failed");
%(fail)s
}
} """ % vars
if self._use_rt_dot_check:
# temporary hack A to 1D for vector-vector dot
code += """
else if (%(A)s->ga.dimensions[0]==1) {
%(out)s->ga.nd = 0;
%(A)s->ga.nd = 1;
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
%(A)s->ga.strides[1] ^= %(A)s->ga.strides[0];
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
}
%(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1];
} else if ( PyGpuArray_DIM(%(A)s, 0) == 1
&&((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0] == (dtype_%(alpha)s)1.
&&((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0] == (dtype_%(beta)s)0.
) {
%(out)s->ga.nd = 0;
%(A)s->ga.nd = 1;
%(A)s->ga.dimensions[0] = %(A)s->ga.dimensions[1];
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
ssize_t a_stride0 = %(A)s->ga.strides[0];
%(A)s->ga.strides[0] = %(A)s->ga.strides[1];
if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) {
%(fail)s
}
%(A)s->ga.dimensions[0] = 1;
if (%(A)s->ga.flags & GA_C_CONTIGUOUS) {
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
%(A)s->ga.strides[1] ^= %(A)s->ga.strides[0];
%(A)s->ga.strides[0] ^= %(A)s->ga.strides[1];
}
%(A)s->ga.nd = 2;
%(out)s->ga.nd = 1;
} """ % vars
code += """
else if (pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
%(out)s, 0) == -1) {
%(A)s->ga.strides[0] = a_stride0;
} else if (pygpu_blas_rdot(%(x)s, %(A)s, %(y)s, 0) == -1) {
%(fail)s
}
%(out)s->ga.nd = 1;
%(A)s->ga.nd = 2;
%(A)s->ga.dimensions[0] = 1;
} else if (
pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
%(out)s, 0) == -1) {
%(fail)s
}
""" % vars
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论