提交 21aed283 authored 作者: James Bergstra's avatar James Bergstra 提交者: Amjad Almahairi

enh: faster CGer

上级 6c7a3cc1
...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
float * zoutdata = (float*)PyArray_DATA(%(Z)s); float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * zdata = (float*)PyArray_DATA(%(A)s); const float * zdata = (float*)PyArray_DATA(%(A)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
double * zoutdata = (double*) PyArray_DATA(%(Z)s); double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s); const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -154,8 +175,57 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -154,8 +175,57 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(Z)s = %(A)s; %(Z)s = %(A)s;
Py_INCREF(%(Z)s); Py_INCREF(%(Z)s);
} }
npy_intp dims[2];
dims[0] = PyArray_DIMS(%(A)s)[0];
dims[1] = PyArray_DIMS(%(A)s)[1];
if ((dims[0] * dims[1]) < 100000)
{
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{
float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
} }
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
}
}
else
{ {
int Nz0 = PyArray_DIMS(%(Z)s)[0]; int Nz0 = PyArray_DIMS(%(Z)s)[0];
int Nz1 = PyArray_DIMS(%(Z)s)[1]; int Nz1 = PyArray_DIMS(%(Z)s)[1];
...@@ -194,11 +264,14 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -194,11 +264,14 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{ {
//fprintf(stderr, "CGer V1 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz0, &Nz1, &alpha, dger_(&Nz0, &Nz1, &alpha,
(double*)x_data, &Sx, (double*)x_data, &Sx,
(double*)y_data, &Sy, (double*)y_data, &Sy,
(double*)(PyArray_DATA(%(Z)s)), &Sz1); (double*)(PyArray_DATA(%(Z)s)), &Sz1);
} }
else { else {
PyErr_SetString(PyExc_NotImplementedError, PyErr_SetString(PyExc_NotImplementedError,
...@@ -221,6 +294,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -221,6 +294,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{ {
//fprintf(stderr, "CGer V2 \\n");
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz1, &Nz0, &alpha, dger_(&Nz1, &Nz0, &alpha,
(double*)y_data, &Sy, (double*)y_data, &Sy,
...@@ -242,6 +316,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -242,6 +316,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(fail)s %(fail)s
} }
} }
}
""" % locals() """ % locals()
...@@ -256,7 +331,7 @@ class CGer(BaseBLAS, Ger): ...@@ -256,7 +331,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (8, blas_header_version()) return (9, blas_header_version())
cger_inplace = CGer(True) cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论