提交 a20c564d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix CGemv when vectors have negative strides

Both BLAS and Numpy were trying to be too clever: BLAS by wanting a pointer to the beginning of the memory chunk (even if it will start from the end), and Numpy by providing a pointer to the first element, which is at the end of the memory buffer when the stride is negative.
上级 98547629
...@@ -351,11 +351,26 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -351,11 +351,26 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
char NOTRANS = 'N'; char NOTRANS = 'N';
int Nx0 = %(xx)s->dimensions[0]; int Nx0 = %(xx)s->dimensions[0];
int Nx1 = %(xx)s->dimensions[1]; int Nx1 = %(xx)s->dimensions[1];
int Sx0 = %(xx)s->strides[0] / elemsize; /* This formula is needed in the case where xx is actually a row or
int Sx1 = %(xx)s->strides[1] / elemsize; * column matrix, because BLAS sometimes insists that the strides:
* - are not smaller than the number of elements in the array
* - are not 0.
*/
int Sx0 = (Nx0 > 1) ? (%(xx)s->strides[0] / elemsize) : (Nx1 + 1);
int Sx1 = (Nx1 > 1) ? (%(xx)s->strides[1] / elemsize) : (Nx0 + 1);
int Sz = %(zz)s->strides[0] / elemsize; int Sz = %(zz)s->strides[0] / elemsize;
int Sy = %(yy)s->strides[0] / elemsize; int Sy = %(yy)s->strides[0] / elemsize;
dtype_%(yy)s* yy_data = (dtype_%(yy)s*) %(yy)s->data;
dtype_%(zz)s* zz_data = (dtype_%(zz)s*) %(zz)s->data;
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sy < 0)
yy_data += (Nx1 - 1) * Sy;
if (Sz < 0)
zz_data += (Nx0 - 1) * Sz;
if (Nx0 * Nx1) if (Nx0 * Nx1)
{ {
// If xx is neither C- nor F-contiguous, we make a copy. // If xx is neither C- nor F-contiguous, we make a copy.
...@@ -364,8 +379,10 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -364,8 +379,10 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
// gemv on reversed matrix and vectors // gemv on reversed matrix and vectors
// - if the copy is too long, maybe call vector/vector dot on // - if the copy is too long, maybe call vector/vector dot on
// each row instead // each row instead
if ((%(xx)s->strides[0] != elemsize) if ((%(xx)s->strides[0] < 0)
&& (%(xx)s->strides[1] != elemsize)) || (%(xx)s->strides[0] < 0)
|| ((%(xx)s->strides[0] != elemsize)
&& (%(xx)s->strides[1] != elemsize)))
{ {
npy_intp dims[2]; npy_intp dims[2];
dims[0] = Nx0; dims[0] = Nx0;
...@@ -376,8 +393,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -376,8 +393,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
%(fail)s %(fail)s
Py_XDECREF(%(xx)s); Py_XDECREF(%(xx)s);
%(xx)s = xx_copy; %(xx)s = xx_copy;
Sx0 = %(xx)s->strides[0] / elemsize; Sx0 = (Nx0 > 1) ? (%(xx)s->strides[0] / elemsize) : (Nx1 + 1);
Sx1 = %(xx)s->strides[1] / elemsize; Sx1 = (Nx1 > 1) ? (%(xx)s->strides[1] / elemsize) : (Nx0 + 1);
} }
if (%(xx)s->strides[0] == elemsize) if (%(xx)s->strides[0] == elemsize)
...@@ -389,9 +406,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -389,9 +406,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
sgemv_(&NOTRANS, &Nx0, &Nx1, sgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha, &alpha,
(float*)(%(xx)s->data), &Sx1, (float*)(%(xx)s->data), &Sx1,
(float*)(%(yy)s->data), &Sy, (float*)yy_data, &Sy,
&fbeta, &fbeta,
(float*)(%(zz)s->data), &Sz); (float*)zz_data, &Sz);
} }
else if (%(xx)s->descr->type_num == PyArray_DOUBLE) else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{ {
...@@ -399,9 +416,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -399,9 +416,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
dgemv_(&NOTRANS, &Nx0, &Nx1, dgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha, &alpha,
(double*)(%(xx)s->data), &Sx1, (double*)(%(xx)s->data), &Sx1,
(double*)(%(yy)s->data), &Sy, (double*)yy_data, &Sy,
&dbeta, &dbeta,
(double*)(%(zz)s->data), &Sz); (double*)zz_data, &Sz);
} }
else else
{ {
...@@ -420,9 +437,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -420,9 +437,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
sgemv_(&TRANS, &Nx1, &Nx0, sgemv_(&TRANS, &Nx1, &Nx0,
&alpha, &alpha,
(float*)(%(xx)s->data), &Sx0, (float*)(%(xx)s->data), &Sx0,
(float*)(%(yy)s->data), &Sy, (float*)yy_data, &Sy,
&fbeta, &fbeta,
(float*)(%(zz)s->data), &Sz); (float*)zz_data, &Sz);
} }
else if (%(xx)s->descr->type_num == PyArray_DOUBLE) else if (%(xx)s->descr->type_num == PyArray_DOUBLE)
{ {
...@@ -430,9 +447,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -430,9 +447,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
dgemv_(&TRANS, &Nx1, &Nx0, dgemv_(&TRANS, &Nx1, &Nx0,
&alpha, &alpha,
(double*)(%(xx)s->data), &Sx0, (double*)(%(xx)s->data), &Sx0,
(double*)(%(yy)s->data), &Sy, (double*)yy_data, &Sy,
&dbeta, &dbeta,
(double*)(%(zz)s->data), &Sz); (double*)zz_data, &Sz);
} }
else else
{ {
...@@ -475,7 +492,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -475,7 +492,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论