提交 00225377 作者: James Bergstra 提交者: Frederic

CGemv handles 0-size arguments properly.

上级 b3aa3073
......@@ -332,73 +332,84 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
int Sz = %(zz)s->strides[0] / elemsize;
int Sy = %(yy)s->strides[0] / elemsize;
if (1)
{
if (%(xx)s->strides[0] == elemsize)
if (Nx0 * Nx1)
{
if (%(xx)s->descr->type_num == PyArray_FLOAT)
if (%(xx)s->strides[0] == elemsize)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
sgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(float*)(%(xx)s->data), &Sx1,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
if (%(xx)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
sgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(float*)(%(xx)s->data), &Sx1,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(double*)(%(xx)s->data), &Sx1,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
{
assert(0);
}
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
else if (%(xx)s->strides[1] == elemsize)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(double*)(%(xx)s->data), &Sx1,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
if (%(xx)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(float*)(%(xx)s->data), &Sx0,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(double*)(%(xx)s->data), &Sx0,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
{
assert(0);
}
}
else
{
assert(0);
// if xx is strided in both directions, then just do the gemv with a
// pair of for loops.
assert (0);
}
}
else if (%(xx)s->strides[1] == elemsize)
else if (dbeta != 1.0)
{
if (%(xx)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(float*)(%(xx)s->data), &Sx0,
(float*)(%(yy)s->data), &Sy,
&fbeta,
(float*)(%(zz)s->data), &Sz);
}
else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(double*)(%(xx)s->data), &Sx0,
(double*)(%(yy)s->data), &Sy,
&dbeta,
(double*)(%(zz)s->data), &Sz);
}
else
// the matrix has at least one dim of length 0
// so we do this loop, which either iterates over 0 elements
// or else it does the right thing for length-0 x.
dtype_%(zz)s * zptr = (dtype_%(zz)s*)(%(zz)s->data);
for (int i = 0; i < Nx0; ++i)
{
assert(0);
zptr[i * Sz] *= dbeta;
}
}
else
{
// if xx is strided in both directions, then just do the gemv with a
// pair of for loops.
assert (0);
}
} // if(1)
}
""" % locals()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论