提交 20d65357 authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #1202 from hunse/master

CGemv now uses vector-vector dot product when appropriate
...@@ -492,27 +492,49 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -492,27 +492,49 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{ {
if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "B %%i %%i %%i %%i\\n",
// Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy); // Check for vector-vector dot (Nx0 == 1). The code may work
sgemv_(&TRANS, &Nx1, &Nx0, // for Sx1 != 1 as well, but has not been tested for this case,
&alpha, // so Sx1 == 1 is required for safety.
(float*)(PyArray_DATA(%(xx)s)), &Sx0, if (Nx0 == 1 && Sx1 == 1)
(float*)yy_data, &Sy, {
&fbeta, zz_data[0] = fbeta*zz_data[0] + alpha*sdot_(&Nx1,
(float*)zz_data, &Sz); (float*)(PyArray_DATA(%(xx)s)), &Sx1,
(float*)yy_data, &Sy);
}
else
{
sgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(float*)(PyArray_DATA(%(xx)s)), &Sx0,
(float*)yy_data, &Sy,
&fbeta,
(float*)zz_data, &Sz);
}
} }
else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha, // Check for vector-vector dot (Nx0 == 1). The code may work
(double*)(PyArray_DATA(%(xx)s)), &Sx0, // for Sx1 != 1 as well, but has not been tested for this case,
(double*)yy_data, &Sy, // so Sx1 == 1 is required for safety.
&dbeta, if (Nx0 == 1 && Sx1 == 1)
(double*)zz_data, &Sz); {
zz_data[0] = dbeta*zz_data[0] + alpha*ddot_(&Nx1,
(double*)(PyArray_DATA(%(xx)s)), &Sx1,
(double*)yy_data, &Sy);
}
else
{
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(double*)(PyArray_DATA(%(xx)s)), &Sx0,
(double*)yy_data, &Sy,
&dbeta,
(double*)zz_data, &Sz);
}
} }
else else
{ {
...@@ -556,7 +578,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -556,7 +578,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (9,) return (10,)
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论