提交 15e577fd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix negative strides problems in CGer

上级 a20c564d
...@@ -58,9 +58,10 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -58,9 +58,10 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
// copy A if !self.destructive or A is fully strided // copy A if !self.destructive or A is fully strided
if (!%(destructive)s if (!%(destructive)s
|| (%(A)s->strides[0] < 0)
|| (%(A)s->strides[1] < 0)
|| ((%(A)s->strides[0] != elemsize) || ((%(A)s->strides[0] != elemsize)
&& && (%(A)s->strides[1] != elemsize)))
(%(A)s->strides[1] != elemsize)))
{ {
npy_intp dims[2]; npy_intp dims[2];
dims[0] = %(A)s->dimensions[0]; dims[0] = %(A)s->dimensions[0];
...@@ -68,7 +69,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -68,7 +69,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
if ((NULL == %(Z)s) if ((NULL == %(Z)s)
|| (%(Z)s->dimensions[0] != %(A)s->dimensions[0]) || (%(Z)s->dimensions[0] != %(A)s->dimensions[0])
|| (%(Z)s->dimensions[1] != %(A)s->dimensions[1])) || (%(Z)s->dimensions[1] != %(A)s->dimensions[1])
|| (%(Z)s->strides[0] < 0)
|| (%(Z)s->strides[1] < 0)
|| ((%(Z)s->strides[0] != elemsize)
&& (%(Z)s->strides[1] != elemsize)))
{ {
if (%(Z)s) Py_XDECREF(%(Z)s); if (%(Z)s) Py_XDECREF(%(Z)s);
%(Z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(A)s)); %(Z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(A)s));
...@@ -146,8 +151,17 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -146,8 +151,17 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
int Sz0 = (Nz0 > 1) ? (%(Z)s->strides[0] / elemsize) : (Nz1 + 1); int Sz0 = (Nz0 > 1) ? (%(Z)s->strides[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (%(Z)s->strides[1] / elemsize) : (Nz0 + 1); int Sz1 = (Nz1 > 1) ? (%(Z)s->strides[1] / elemsize) : (Nz0 + 1);
if (1) printf("Sz: %%i, %%i\\n", Sz0, Sz1);
{ dtype_%(x)s* x_data = (dtype_%(x)s*) %(x)s->data;
dtype_%(y)s* y_data = (dtype_%(y)s*) %(y)s->data;
// ger expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
y_data += (Nz1 - 1) * Sy;
if (%(Z)s->strides[0] == elemsize) if (%(Z)s->strides[0] == elemsize)
{ {
if (%(Z)s->descr->type_num == PyArray_FLOAT) if (%(Z)s->descr->type_num == PyArray_FLOAT)
...@@ -155,16 +169,16 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -155,16 +169,16 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
//fprintf(stderr, "A\\n"); //fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)%(a)s->data)[0]; float alpha = ((dtype_%(a)s*)%(a)s->data)[0];
sger_(&Nz0, &Nz1, &alpha, sger_(&Nz0, &Nz1, &alpha,
(float*)(%(x)s->data), &Sx, (float*)x_data, &Sx,
(float*)(%(y)s->data), &Sy, (float*)y_data, &Sy,
(float*)(%(Z)s->data), &Sz1); (float*)(%(Z)s->data), &Sz1);
} }
else if (%(Z)s->descr->type_num == PyArray_DOUBLE) else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
{ {
double alpha = ((dtype_%(a)s*)%(a)s->data)[0]; double alpha = ((dtype_%(a)s*)%(a)s->data)[0];
dger_(&Nz0, &Nz1, &alpha, dger_(&Nz0, &Nz1, &alpha,
(double*)(%(x)s->data), &Sx, (double*)x_data, &Sx,
(double*)(%(y)s->data), &Sy, (double*)y_data, &Sy,
(double*)(%(Z)s->data), &Sz1); (double*)(%(Z)s->data), &Sz1);
} }
else { else {
...@@ -181,16 +195,16 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -181,16 +195,16 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
//fprintf(stderr, "alpha=%%f\\n", alpha); //fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy); //fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha, sger_(&Nz1, &Nz0, &alpha,
(float*)(%(y)s->data), &Sy, (float*)y_data, &Sy,
(float*)(%(x)s->data), &Sx, (float*)x_data, &Sx,
(float*)(%(Z)s->data), &Sz0); (float*)(%(Z)s->data), &Sz0);
} }
else if (%(Z)s->descr->type_num == PyArray_DOUBLE) else if (%(Z)s->descr->type_num == PyArray_DOUBLE)
{ {
double alpha = ((dtype_%(a)s*)%(a)s->data)[0]; double alpha = ((dtype_%(a)s*)%(a)s->data)[0];
dger_(&Nz1, &Nz0, &alpha, dger_(&Nz1, &Nz0, &alpha,
(double*)(%(y)s->data), &Sy, (double*)y_data, &Sy,
(double*)(%(x)s->data), &Sx, (double*)x_data, &Sx,
(double*)(%(Z)s->data), &Sz0); (double*)(%(Z)s->data), &Sz0);
} }
else else
...@@ -206,7 +220,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -206,7 +220,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
"into a memory-contiguous one."); "into a memory-contiguous one.");
%(fail)s %(fail)s
} }
}
} }
""" % locals() """ % locals()
...@@ -221,7 +234,7 @@ class CGer(BaseBLAS, Ger): ...@@ -221,7 +234,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (7,)
@local_optimizer([ger, ger_destructive]) @local_optimizer([ger, ger_destructive])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论