提交 21aed283 authored 作者: James Bergstra's avatar James Bergstra 提交者: Amjad Almahairi

enh: faster CGer

上级 6c7a3cc1
......@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{
float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * zdata = (float*)PyArray_DATA(%(A)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
}
}
}
......@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{
double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
}
}
}
......@@ -154,93 +175,147 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(Z)s = %(A)s;
Py_INCREF(%(Z)s);
}
}
{
int Nz0 = PyArray_DIMS(%(Z)s)[0];
int Nz1 = PyArray_DIMS(%(Z)s)[1];
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(y)s)[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (PyArray_STRIDES(%(Z)s)[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (PyArray_STRIDES(%(Z)s)[1] / elemsize) : (Nz0 + 1);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(y)s* y_data = (dtype_%(y)s*) PyArray_DATA(%(y)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
y_data += (Nz1 - 1) * Sy;
if (PyArray_STRIDES(%(Z)s)[0] == elemsize)
npy_intp dims[2];
dims[0] = PyArray_DIMS(%(A)s)[0];
dims[1] = PyArray_DIMS(%(A)s)[1];
if ((dims[0] * dims[1]) < 100000)
{
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
sger_(&Nz0, &Nz1, &alpha,
(float*)x_data, &Sx,
(float*)y_data, &Sy,
(float*)(PyArray_DATA(%(Z)s)), &Sz1);
float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)x_data, &Sx,
(double*)y_data, &Sy,
(double*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else {
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
}
}
else if (PyArray_STRIDES(%(Z)s)[1] == elemsize)
else
{
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
int Nz0 = PyArray_DIMS(%(Z)s)[0];
int Nz1 = PyArray_DIMS(%(Z)s)[1];
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(y)s)[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (PyArray_STRIDES(%(Z)s)[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (PyArray_STRIDES(%(Z)s)[1] / elemsize) : (Nz0 + 1);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(y)s* y_data = (dtype_%(y)s*) PyArray_DATA(%(y)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
y_data += (Nz1 - 1) * Sy;
if (PyArray_STRIDES(%(Z)s)[0] == elemsize)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha,
(float*)y_data, &Sy,
(float*)x_data, &Sx,
(float*)(PyArray_DATA(%(Z)s)), &Sz0);
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
sger_(&Nz0, &Nz1, &alpha,
(float*)x_data, &Sx,
(float*)y_data, &Sy,
(float*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
//fprintf(stderr, "CGer V1 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)x_data, &Sx,
(double*)y_data, &Sy,
(double*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else {
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
else if (PyArray_STRIDES(%(Z)s)[1] == elemsize)
{
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)y_data, &Sy,
(double*)x_data, &Sx,
(double*)(PyArray_DATA(%(Z)s)), &Sz0);
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha,
(float*)y_data, &Sy,
(float*)x_data, &Sx,
(float*)(PyArray_DATA(%(Z)s)), &Sz0);
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
//fprintf(stderr, "CGer V2 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)y_data, &Sy,
(double*)x_data, &Sx,
(double*)(PyArray_DATA(%(Z)s)), &Sz0);
}
else
{
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
PyErr_SetString(PyExc_AssertionError,
"A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s
}
}
""" % locals()
......@@ -256,7 +331,7 @@ class CGer(BaseBLAS, Ger):
return code
def c_code_cache_version(self):
return (8, blas_header_version())
return (9, blas_header_version())
cger_inplace = CGer(True)
cger_no_inplace = CGer(False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论