提交 21aed283 authored 作者: James Bergstra's avatar James Bergstra 提交者: Amjad Almahairi

enh: faster CGer

上级 6c7a3cc1
......@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{
float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * zdata = (float*)PyArray_DATA(%(A)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
}
}
}
......@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{
double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
}
}
}
......@@ -154,8 +175,57 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(Z)s = %(A)s;
Py_INCREF(%(Z)s);
}
npy_intp dims[2];
dims[0] = PyArray_DIMS(%(A)s)[0];
dims[1] = PyArray_DIMS(%(A)s)[1];
if ((dims[0] * dims[1]) < 100000)
{
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{
float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
}
}
else
{
int Nz0 = PyArray_DIMS(%(Z)s)[0];
int Nz1 = PyArray_DIMS(%(Z)s)[1];
......@@ -194,11 +264,14 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
//fprintf(stderr, "CGer V1 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)x_data, &Sx,
(double*)y_data, &Sy,
(double*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else {
PyErr_SetString(PyExc_NotImplementedError,
......@@ -221,6 +294,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
//fprintf(stderr, "CGer V2 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)y_data, &Sy,
......@@ -242,6 +316,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(fail)s
}
}
}
""" % locals()
......@@ -256,7 +331,7 @@ class CGer(BaseBLAS, Ger):
return code
def c_code_cache_version(self):
return (8, blas_header_version())
return (9, blas_header_version())
cger_inplace = CGer(True)
cger_no_inplace = CGer(False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论