提交 00225377 作者: James Bergstra 提交者: Frederic

CGemv handles 0-size arguments properly.

上级 b3aa3073
...@@ -332,73 +332,84 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -332,73 +332,84 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
int Sz = %(zz)s->strides[0] / elemsize; int Sz = %(zz)s->strides[0] / elemsize;
int Sy = %(yy)s->strides[0] / elemsize; int Sy = %(yy)s->strides[0] / elemsize;
if (1) if (Nx0 * Nx1)
{
if (%(xx)s->strides[0] == elemsize)
{ {
if (%(xx)s->descr->type_num == PyArray_FLOAT) if (%(xx)s->strides[0] == elemsize)
{ {
//fprintf(stderr, "A\\n"); if (%(xx)s->descr->type_num == PyArray_FLOAT)
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0]; {
sgemv_(&NOTRANS, &Nx0, &Nx1, //fprintf(stderr, "A\\n");
&alpha, float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
(float*)(%(xx)s->data), &Sx1, sgemv_(&NOTRANS, &Nx0, &Nx1,
(float*)(%(yy)s->data), &Sy, &alpha,
&fbeta, (float*)(%(xx)s->data), &Sx1,
(float*)(%(zz)s->data), &Sz); (float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(double*)(%(xx)s->data), &Sx1,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
{
assert(0);
}
} }
else if (%(xx)s->descr->type_num == PyArray_DOUBLE) else if (%(xx)s->strides[1] == elemsize)
{ {
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0]; if (%(xx)s->descr->type_num == PyArray_FLOAT)
dgemv_(&NOTRANS, &Nx0, &Nx1, {
&alpha, //fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
(double*)(%(xx)s->data), &Sx1, float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
(double*)(%(yy)s->data), &Sy, //fprintf(stderr, "alpha=%%f\\n", alpha);
&dbeta, //fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
(double*)(%(zz)s->data), &Sz); sgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(float*)(%(xx)s->data), &Sx0,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(double*)(%(xx)s->data), &Sx0,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
{
assert(0);
}
} }
else else
{ {
assert(0); // if xx is strided in both directions, then just do the gemv with a
// pair of for loops.
assert (0);
} }
} }
else if (%(xx)s->strides[1] == elemsize) else if (dbeta != 1.0)
{ {
if (%(xx)s->descr->type_num == PyArray_FLOAT) // the matrix has at least one dim of length 0
{ // so we do this loop, which either iterates over 0 elements
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1); // or else it does the right thing for length-0 x.
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0]; dtype_%(zz)s * zptr = (dtype_%(zz)s*)(%(zz)s->data);
//fprintf(stderr, "alpha=%%f\\n", alpha); for (int i = 0; i < Nx0; ++i)
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(float*)(%(xx)s->data), &Sx0,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(double*)(%(xx)s->data), &Sx0,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
{ {
assert(0); zptr[i * Sz] *= dbeta;
} }
} }
else
{
// if xx is strided in both directions, then just do the gemv with a
// pair of for loops.
assert (0);
}
} // if(1)
} }
""" % locals() """ % locals()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论