提交 daa3b57b authored 作者: James Bergstra's avatar James Bergstra 提交者: Frederic

added CGemv to blas_c

上级 dfd1c614
from theano.tensor.opt import in2out
from theano.gof import Op
from blas import Ger, ger, ger_destructive
from blas import ldflags, blas_header_text
from blas import blas_optdb, optdb, local_optimizer
from blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer
from blas import Ger, ger, ger_destructive
from blas import Gemv, gemv_inplace, gemv_no_inplace
ger_c_code = """
int elemsize ;
class BaseBLAS(object):
def c_libraries(self):
return ldflags()
def c_compile_args(self):
return ldflags(libs=False, flags=True)
def c_lib_dirs(self):
return ldflags(libs=False, libs_dir=True)
def c_header_dirs(self):
return ldflags(libs=False, include_dir=True)
def c_support_code(self):
return blas_header_text()
if (%(A)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(A) != 2"); %(fail)s;}
if (%(x)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 1"); %(fail)s;}
if (%(y)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 1"); %(fail)s;}
if (%(a)s->nd != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 0"); %(fail)s;}
if (%(A)s->descr->type_num != %(x)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. x"); %(fail)s; }
if (%(A)s->descr->type_num != %(y)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; }
####### ####### #######
# GER
####### ####### #######
if (%(A)s->dimensions[0] != %(x)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[0] != x.shape[0]"); %(fail)s;}
if (%(A)s->dimensions[1] != %(y)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[1] != y.shape[0]"); %(fail)s;}
def ger_c_code(A, a, x, y, Z, destructive, fail):
return """
if (%(A)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(A)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex CGer"); %(fail)s;}
int elemsize ;
if (%(A)s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "rank(A) != 2"); %(fail)s;}
if (%(x)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 1"); %(fail)s;}
if (%(y)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 1"); %(fail)s;}
if (%(a)s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 0"); %(fail)s;}
// copy A if !self.destructive or A is fully strided
if (!%(destructive)s
|| ((%(A)s->strides[0] != elemsize)
&&
(%(A)s->strides[1] != elemsize)))
{
npy_intp dims[2];
dims[0] = %(A)s->dimensions[0];
dims[1] = %(A)s->dimensions[1];
if (%(A)s->descr->type_num != %(x)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. x"); %(fail)s; }
if (%(A)s->descr->type_num != %(y)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; }
if ((NULL == %(Z)s)
|| (%(Z)s->dimensions[0] != %(A)s->dimensions[0])
|| (%(Z)s->dimensions[1] != %(A)s->dimensions[1]))
if (%(A)s->dimensions[0] != %(x)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[0] != x.shape[0]"); %(fail)s;}
if (%(A)s->dimensions[1] != %(y)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[1] != y.shape[0]"); %(fail)s;}
if (%(A)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(A)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex CGer"); %(fail)s;}
// copy A if !self.destructive or A is fully strided
if (!%(destructive)s
|| ((%(A)s->strides[0] != elemsize)
&&
(%(A)s->strides[1] != elemsize)))
{
if (%(Z)s) Py_XDECREF(%(Z)s);
%(Z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(A)s);
if(!%(Z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemm_no_inplace output");
%(fail)s
npy_intp dims[2];
dims[0] = %(A)s->dimensions[0];
dims[1] = %(A)s->dimensions[1];
if ((NULL == %(Z)s)
|| (%(Z)s->dimensions[0] != %(A)s->dimensions[0])
|| (%(Z)s->dimensions[1] != %(A)s->dimensions[1]))
{
if (%(Z)s) Py_XDECREF(%(Z)s);
%(Z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(A)s);
if(!%(Z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc ger output");
%(fail)s
}
}
}
assert (%(Z)s != %(A)s);
if (%(Z)s->descr->type_num == PyArray_FLOAT)
{
float * zoutdata = (float*)%(Z)s->data;
const float * zdata = (float*)%(A)s->data;
int Ai = %(A)s->strides[0]/sizeof(float);
int Aj = %(A)s->strides[1]/sizeof(float);
int Zi = %(Z)s->strides[0]/sizeof(float);
int Zj = %(Z)s->strides[1]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
assert (%(Z)s != %(A)s);
if (%(Z)s->descr->type_num == PyArray_FLOAT)
{
for (int j = 0; j < dims[1]; ++j)
float * zoutdata = (float*)%(Z)s->data;
const float * zdata = (float*)%(A)s->data;
int Ai = %(A)s->strides[0]/sizeof(float);
int Aj = %(A)s->strides[1]/sizeof(float);
int Zi = %(Z)s->strides[0]/sizeof(float);
int Zj = %(Z)s->strides[1]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
}
}
}
}
else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*) %(Z)s->data;
const double * zdata = (double*)%(A)s->data;
int Ai = %(A)s->strides[0]/sizeof(double);
int Aj = %(A)s->strides[1]/sizeof(double);
int Zi = %(Z)s->strides[0]/sizeof(double);
int Zj = %(Z)s->strides[1]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
{
for (int j = 0; j < dims[1]; ++j)
double * zoutdata = (double*) %(Z)s->data;
const double * zdata = (double*)%(A)s->data;
int Ai = %(A)s->strides[0]/sizeof(double);
int Aj = %(A)s->strides[1]/sizeof(double);
int Zi = %(Z)s->strides[0]/sizeof(double);
int Zj = %(Z)s->strides[1]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
zoutdata[Zi*i*+Zj*j] = zdata[Ai*i+Aj*j];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
}
}
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
}
}
else
{
//fprintf(stderr, "USING A\\n");
if (%(Z)s != %(A)s)
{
if (%(Z)s) { Py_DECREF(%(Z)s); }
%(Z)s = %(A)s;
Py_INCREF(%(Z)s);
//fprintf(stderr, "USING A\\n");
if (%(Z)s != %(A)s)
{
if (%(Z)s) { Py_DECREF(%(Z)s); }
%(Z)s = %(A)s;
Py_INCREF(%(Z)s);
}
}
}
{
{
int Nz0 = %(Z)s->dimensions[0];
int Nz1 = %(Z)s->dimensions[1];
int Sz0 = %(Z)s->strides[0] / elemsize;
int Sz1 = %(Z)s->strides[1] / elemsize;
int Sx = %(x)s->strides[0] / elemsize;
int Sy = %(y)s->strides[0] / elemsize;
int Nz0 = %(Z)s->dimensions[0];
int Nz1 = %(Z)s->dimensions[1];
int Sz0 = %(Z)s->strides[0] / elemsize;
int Sz1 = %(Z)s->strides[1] / elemsize;
int Sx = %(x)s->strides[0] / elemsize;
int Sy = %(y)s->strides[0] / elemsize;
if (1)
{
if (%(Z)s->strides[0] == elemsize)
{
if (%(Z)s->descr->type_num == PyArray_FLOAT)
if (1)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)%(a)s->data)[0];
sger_(&Nz0, &Nz1, &alpha,
(float*)(%(x)s->data), &Sx,
(float*)(%(y)s->data), &Sy,
(float*)(%(Z)s->data), &Sz1);
}
else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(a)s*)%(a)s->data)[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)(%(x)s->data), &Sx,
(double*)(%(y)s->data), &Sy,
(double*)(%(Z)s->data), &Sz1);
}
else { assert(0); }
}
else if (%(Z)s->strides[1] == elemsize)
{
if (%(Z)s->descr->type_num == PyArray_FLOAT)
if (%(Z)s->strides[0] == elemsize)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(a)s*)(%(a)s->data))[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha,
(float*)(%(y)s->data), &Sy,
(float*)(%(x)s->data), &Sx,
(float*)(%(Z)s->data), &Sz0);
if (%(Z)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)%(a)s->data)[0];
sger_(&Nz0, &Nz1, &alpha,
(float*)(%(x)s->data), &Sx,
(float*)(%(y)s->data), &Sy,
(float*)(%(Z)s->data), &Sz1);
}
else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(a)s*)%(a)s->data)[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)(%(x)s->data), &Sx,
(double*)(%(y)s->data), &Sy,
(double*)(%(Z)s->data), &Sz1);
}
else { assert(0); }
}
else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
else if (%(Z)s->strides[1] == elemsize)
{
double alpha = ((dtype_%(a)s*)%(a)s->data)[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)(%(y)s->data), &Sy,
(double*)(%(x)s->data), &Sx,
(double*)(%(Z)s->data), &Sz0);
if (%(Z)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(a)s*)(%(a)s->data))[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha,
(float*)(%(y)s->data), &Sy,
(float*)(%(x)s->data), &Sx,
(float*)(%(Z)s->data), &Sz0);
}
else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(a)s*)%(a)s->data)[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)(%(y)s->data), &Sy,
(double*)(%(x)s->data), &Sx,
(double*)(%(Z)s->data), &Sz0);
}
else { assert(0); }
}
else { assert(0); }
}
}
else { assert(0); }
}
}
"""
class CGer(Ger):
""" % locals()
class CGer(BaseBLAS, Ger):
def c_libraries(self):
return ldflags()
......@@ -181,9 +208,9 @@ class CGer(Ger):
print 'C_CODE'
A, a, x, y = inp
Z, = out
destructive = int(self.destructive)
fail = sub['fail']
code = ger_c_code % locals()
code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive),
fail=sub['fail'])
return code
def c_code_cache_version(self):
......@@ -208,16 +235,246 @@ def make_c_ger_destructive(node):
print "inserting destructive C_GER"
return [CGer(True)(*node.inputs)]
use_c_blas = in2out(use_c_ger)
make_c_blas_destructive = in2out(make_c_ger_destructive)
####### ####### #######
# GEMV
####### ####### #######
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
"""
zz <- beta * aa + alpha * dot(xx, yy)
where xx is a matrix, yy and aa are vectors (ergo zz is vector)
"""
return """
int elemsize ;
float fbeta;
double dbeta;
if (%(aa)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1"); %(fail)s;}
if (%(xx)s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2"); %(fail)s;}
if (%(yy)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1"); %(fail)s;}
if (%(alpha)s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0"); %(fail)s;}
if (%(beta)s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0"); %(fail)s;}
if (%(aa)s->descr->type_num != %(xx)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; }
if (%(aa)s->descr->type_num != %(yy)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; }
if (%(xx)s->dimensions[0] != %(aa)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[0] != x.shape[0]"); %(fail)s;}
if (%(xx)s->dimensions[1] != %(yy)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[1] != y.shape[0]"); %(fail)s;}
if (%(aa)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(aa)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex Gemv"); %(fail)s;}
fbeta = dbeta = ((dtype_%(beta)s*)%(beta)s->data)[0];
// copy aa if not destructive
if (!%(destructive)s)
{
if ((NULL == %(zz)s)
|| (%(zz)s->dimensions[0] != %(aa)s->dimensions[0]))
{
if (%(zz)s) Py_XDECREF(%(zz)s);
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1,
%(aa)s->dimensions, type_num_%(aa)s);
if(!%(zz)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemv output");
%(fail)s
}
}
assert (%(zz)s != %(aa)s);
if (dbeta != 0)
{
if (%(zz)s->descr->type_num == PyArray_FLOAT)
{
float * zoutdata = (float*)%(zz)s->data;
const float * zdata = (float*)%(aa)s->data;
int Ai = %(aa)s->strides[0]/sizeof(float);
int Zi = %(zz)s->strides[0]/sizeof(float);
for (int i = 0; i < %(aa)s->dimensions[0]; ++i)
{
zoutdata[Zi*i] = fbeta * zdata[Ai*i];
}
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*) %(zz)s->data;
const double * zdata = (double*)%(aa)s->data;
int Ai = %(aa)s->strides[0]/sizeof(double);
int Zi = %(zz)s->strides[0]/sizeof(double);
for (int i = 0; i < %(aa)s->dimensions[0]; ++i)
{
zoutdata[Zi*i] = dbeta * zdata[Ai*i];
}
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
}
fbeta = dbeta = 1.0;
}
}
else
{
//fprintf(stderr, "Gemv working in-place \\n");
if (%(zz)s != %(aa)s)
{
if (%(zz)s) { Py_DECREF(%(zz)s); }
%(zz)s = %(aa)s;
Py_INCREF(%(zz)s);
}
}
{
char TRANS = 'T';
char NOTRANS = 'N';
int Nx0 = %(xx)s->dimensions[0];
int Nx1 = %(xx)s->dimensions[1];
int Sx0 = %(xx)s->strides[0] / elemsize;
int Sx1 = %(xx)s->strides[1] / elemsize;
int Sz = %(zz)s->strides[0] / elemsize;
int Sy = %(yy)s->strides[0] / elemsize;
if (1)
{
if (%(xx)s->strides[0] == elemsize)
{
if (%(xx)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
sgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(float*)(%(xx)s->data), &Sx1,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(double*)(%(xx)s->data), &Sx1,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
{
assert(0);
}
}
else if (%(xx)s->strides[1] == elemsize)
{
if (%(xx)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(float*)(%(xx)s->data), &Sx0,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(double*)(%(xx)s->data), &Sx0,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
{
assert(0);
}
}
else
{
// if xx is strided in both directions, then just do the gemv with a
// pair of for loops.
assert (0);
}
} // if(1)
}
""" % locals()
class CGemv(BaseBLAS, Gemv):
def c_code(self, node, name, inp, out, sub):
print 'GEMV C_CODE'
aa, alpha, xx, yy, beta = inp
zz, = out
code = gemv_c_code(
aa, xx, yy, zz, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'])
return code
def c_code_cache_version(self):
return ()
def make_thunk(*args, **kwargs):
return Op.make_thunk(*args, **kwargs)
@local_optimizer([gemv_inplace, gemv_no_inplace])
def use_c_gemv(node):
if node.op == gemv_no_inplace:
print "inserting C_GEMV"
return [CGemv(inplace=False)(*node.inputs)]
if node.op == gemv_inplace:
print "inserting dstruc C_GEMV"
return [CGemv(inplace=True)(*node.inputs)]
@local_optimizer([CGemv(inplace=False)])
def make_c_gemv_destructive(node):
if node.op == CGemv(inplace=False):
print "inserting destructive C_GER"
return [CGemv(inplace=True)(*node.inputs)]
####### ####### #######
# Optimizers
####### ####### #######
blas_optdb.register('c_blas',
use_c_blas,
90, 'fast_run')
EquilibriumOptimizer([
use_c_ger,
use_c_gemv,
],
max_use_ratio=5),
20, 'fast_run')
print 'BLAS_OPTDB'
print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py
optdb.register('make_c_blas_destructive',
make_c_blas_destructive,
optdb.register('c_blas_destructive',
EquilibriumOptimizer([
make_c_ger_destructive,
make_c_gemv_destructive,
],
max_use_ratio=5),
70.0, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论