提交 21aed283 作者: James Bergstra 提交者: Amjad Almahairi

enh: faster CGer

上级 6c7a3cc1
...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
float * zoutdata = (float*)PyArray_DATA(%(Z)s); float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * zdata = (float*)PyArray_DATA(%(A)s); const float * zdata = (float*)PyArray_DATA(%(A)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
double * zoutdata = (double*) PyArray_DATA(%(Z)s); double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s); const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -154,93 +175,147 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -154,93 +175,147 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(Z)s = %(A)s; %(Z)s = %(A)s;
Py_INCREF(%(Z)s); Py_INCREF(%(Z)s);
} }
} npy_intp dims[2];
dims[0] = PyArray_DIMS(%(A)s)[0];
{ dims[1] = PyArray_DIMS(%(A)s)[1];
int Nz0 = PyArray_DIMS(%(Z)s)[0]; if ((dims[0] * dims[1]) < 100000)
int Nz1 = PyArray_DIMS(%(Z)s)[1];
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(y)s)[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (PyArray_STRIDES(%(Z)s)[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (PyArray_STRIDES(%(Z)s)[1] / elemsize) : (Nz0 + 1);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(y)s* y_data = (dtype_%(y)s*) PyArray_DATA(%(y)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
y_data += (Nz1 - 1) * Sy;
if (PyArray_STRIDES(%(Z)s)[0] == elemsize)
{ {
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "A\\n"); float * zoutdata = (float*)PyArray_DATA(%(Z)s);
float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; const float * xdata = (float*)PyArray_DATA(%(x)s);
sger_(&Nz0, &Nz1, &alpha, const float * ydata = (float*)PyArray_DATA(%(y)s);
(float*)x_data, &Sx, const float * adata = (float*)PyArray_DATA(%(a)s);
(float*)y_data, &Sy, const float alpha = adata[0];
(float*)(PyArray_DATA(%(Z)s)), &Sz1); float tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
} }
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; double * zoutdata = (double*) PyArray_DATA(%(Z)s);
dger_(&Nz0, &Nz1, &alpha, const double * zdata = (double*)PyArray_DATA(%(A)s);
(double*)x_data, &Sx, const double * xdata = (double*)PyArray_DATA(%(x)s);
(double*)y_data, &Sy, const double * ydata = (double*)PyArray_DATA(%(y)s);
(double*)(PyArray_DATA(%(Z)s)), &Sz1); const double * adata = (double*)PyArray_DATA(%(a)s);
} const double alpha = adata[0];
else { double tmp, axi;
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double"); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
%(fail)s int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
} }
} }
else if (PyArray_STRIDES(%(Z)s)[1] == elemsize) else
{ {
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT) int Nz0 = PyArray_DIMS(%(Z)s)[0];
int Nz1 = PyArray_DIMS(%(Z)s)[1];
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(y)s)[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (PyArray_STRIDES(%(Z)s)[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (PyArray_STRIDES(%(Z)s)[1] / elemsize) : (Nz0 + 1);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(y)s* y_data = (dtype_%(y)s*) PyArray_DATA(%(y)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
y_data += (Nz1 - 1) * Sy;
if (PyArray_STRIDES(%(Z)s)[0] == elemsize)
{ {
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1); if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0]; {
//fprintf(stderr, "alpha=%%f\\n", alpha); //fprintf(stderr, "A\\n");
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy); float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
sger_(&Nz1, &Nz0, &alpha, sger_(&Nz0, &Nz1, &alpha,
(float*)y_data, &Sy, (float*)x_data, &Sx,
(float*)x_data, &Sx, (float*)y_data, &Sy,
(float*)(PyArray_DATA(%(Z)s)), &Sz0); (float*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
//fprintf(stderr, "CGer V1 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)x_data, &Sx,
(double*)y_data, &Sy,
(double*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else {
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
} }
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE) else if (PyArray_STRIDES(%(Z)s)[1] == elemsize)
{ {
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
dger_(&Nz1, &Nz0, &alpha, {
(double*)y_data, &Sy, //fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
(double*)x_data, &Sx, float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0];
(double*)(PyArray_DATA(%(Z)s)), &Sz0); //fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha,
(float*)y_data, &Sy,
(float*)x_data, &Sx,
(float*)(PyArray_DATA(%(Z)s)), &Sz0);
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
//fprintf(stderr, "CGer V2 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)y_data, &Sy,
(double*)x_data, &Sx,
(double*)(PyArray_DATA(%(Z)s)), &Sz0);
}
else
{
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
} }
else else
{ {
PyErr_SetString(PyExc_NotImplementedError, PyErr_SetString(PyExc_AssertionError,
"not float nor double"); "A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s %(fail)s
} }
} }
else
{
PyErr_SetString(PyExc_AssertionError,
"A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s
}
} }
""" % locals() """ % locals()
...@@ -256,7 +331,7 @@ class CGer(BaseBLAS, Ger): ...@@ -256,7 +331,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (8, blas_header_version()) return (9, blas_header_version())
cger_inplace = CGer(True) cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论