提交 709f745c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Fix bug in handling of row/column matrices in GEMV c_code

Bug was caused by reusing the adjusted strides in the logic to decide whether the call to GEMV should be transposed or not. Particularly the +1 in the strides variable was causing the error branch (no double-strides) to be reached wrongly. The +1 was supposedly there for the case of matrix with length 0, but that triggers a branch where the adjusted strides are never used. This bug was introduced in afe934b2
上级 86cc3b87
...@@ -344,6 +344,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -344,6 +344,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
""" """
code = """ code = """
bool is_float;
int elemsize; int elemsize;
float fbeta; float fbeta;
double dbeta; double dbeta;
...@@ -361,11 +362,23 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -361,11 +362,23 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
%(fail)s; %(fail)s;
} }
if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { elemsize = 8; } if ((PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(x)s)->type_num)
else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { elemsize = 4;} || (PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(A)s)->type_num))
{
PyErr_SetString(PyExc_TypeError, "GEMV: dtypes of A, x, y do not match");
%(fail)s;
}
if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) {
is_float = 0;
elemsize = 8;
}
else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) {
elemsize = 4;
is_float = 1;
}
else { else {
PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
%(fail)s; %(fail)s;
PyErr_SetString(PyExc_NotImplementedError, "GEMV: Inputs must be float or double");
} }
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0]; fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
...@@ -408,37 +421,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -408,37 +421,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
Py_INCREF(%(z)s); Py_INCREF(%(z)s);
} }
} }
{ {
char TRANS = 'T';
char NOTRANS = 'N';
int NA0 = PyArray_DIMS(%(A)s)[0]; int NA0 = PyArray_DIMS(%(A)s)[0];
int NA1 = PyArray_DIMS(%(A)s)[1]; int NA1 = PyArray_DIMS(%(A)s)[1];
/* This formula is needed in the case where A is actually a row or
* column matrix, because BLAS sometimes insists that the strides:
* - are not smaller than the number of elements in the array
* - are not 0.
*/
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (NA1 - 1) * Sx;
if (Sz < 0)
z_data += (NA0 - 1) * Sz;
if (NA0 * NA1) if (NA0 * NA1)
{ {
// Non-empty A matrix
/* In the case where A is actually a row or column matrix,
* the strides corresponding to the dummy dimension don't matter,
* but BLAS requires these to be no smaller than the number of elements in the array.
*/
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (NA1 - 1) * Sx;
if (Sz < 0)
z_data += (NA0 - 1) * Sz;
if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) ) if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
{ {
// We can treat the array A as C-or F-contiguous by changing the order of iteration // We can treat the array A as C-or F-contiguous by changing the order of iteration
// printf("GEMV: Iterating in reverse NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1);
if (SA0 < 0){ if (SA0 < 0){
A_data += (NA0 -1) * SA0; // Jump to first row A_data += (NA0 -1) * SA0; // Jump to first row
SA0 = -SA0; // Iterate over rows in reverse SA0 = -SA0; // Iterate over rows in reverse
...@@ -452,27 +468,45 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -452,27 +468,45 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
} else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1))) } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
{ {
// Array isn't contiguous, we have to make a copy // Array isn't contiguous, we have to make a copy
// - if the copy is too long, maybe call vector/vector dot on // - if the copy is too long, maybe call vector/vector dot on each row instead
// each row instead // printf("GEMV: Making a copy NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1);
// printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1);
npy_intp dims[2]; npy_intp dims[2];
dims[0] = NA0; dims[0] = NA0;
dims[1] = NA1; dims[1] = NA1;
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(%(A)s);
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
%(A)s);
if (!A_copy) if (!A_copy)
%(fail)s %(fail)s
Py_XDECREF(%(A)s); Py_XDECREF(%(A)s);
%(A)s = A_copy; %(A)s = A_copy;
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1); SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1); SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s); A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
} }
//else {printf("GEMV: Using the original array NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1);}
if (SA0 == 1) if (NA0 == 1)
{
// Vector-vector dot product, it seems faster to avoid GEMV
dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
if (is_float)
{
z_data[0] *= fbeta;
z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
(float*)x_data, &Sx);
}
else
{
z_data[0] *= dbeta;
z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
(double*)x_data, &Sx);
}
}
else if (SA0 == 1)
{ {
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) // F-contiguous
char NOTRANS = 'N';
if (is_float)
{ {
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&NOTRANS, &NA0, &NA1, sgemv_(&NOTRANS, &NA0, &NA1,
...@@ -482,7 +516,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -482,7 +516,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
&fbeta, &fbeta,
(float*)z_data, &Sz); (float*)z_data, &Sz);
} }
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE) else
{ {
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &NA0, &NA1, dgemv_(&NOTRANS, &NA0, &NA1,
...@@ -492,97 +526,39 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -492,97 +526,39 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
&dbeta, &dbeta,
(double*)z_data, &Sz); (double*)z_data, &Sz);
} }
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
} }
else if (SA1 == 1) else if (SA1 == 1)
{ {
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) // C-contiguous
char TRANS = 'T';
if (is_float)
{ {
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&TRANS, &NA1, &NA0,
// Check for vector-vector dot (NA0 == 1). The code may work &alpha,
// for SA1 != 1 as well, but has not been tested for this case, (float*)(A_data), &SA0,
// so SA1 == 1 is required for safety. (float*)x_data, &Sx,
if (NA0 == 1 && SA1 == 1) &fbeta,
{ (float*)z_data, &Sz);
if (fbeta != 0.f) {
z_data[0] = fbeta*z_data[0];
} else {
z_data[0] = 0.f;
}
z_data[0] += alpha*sdot_(&NA1,
(float*)(A_data), &SA1,
(float*)x_data, &Sx);
}
else
{
sgemv_(&TRANS, &NA1, &NA0,
&alpha,
(float*)(A_data), &SA0,
(float*)x_data, &Sx,
&fbeta,
(float*)z_data, &Sz);
}
}
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
// Check for vector-vector dot (NA0 == 1). The code may work
// for SA1 != 1 as well, but has not been tested for this case,
// so SA1 == 1 is required for safety.
if (NA0 == 1 && SA1 == 1)
{
if (dbeta != 0.) {
z_data[0] = dbeta*z_data[0];
} else {
z_data[0] = 0.;
}
z_data[0] += alpha*ddot_(&NA1,
(double*)(A_data), &SA1,
(double*)x_data, &Sx);
}
else
{
dgemv_(&TRANS, &NA1, &NA0,
&alpha,
(double*)(A_data), &SA0,
(double*)x_data, &Sx,
&dbeta,
(double*)z_data, &Sz);
}
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
"neither float nor double dtype"); dgemv_(&TRANS, &NA1, &NA0,
%(fail)s &alpha,
(double*)(A_data), &SA0,
(double*)x_data, &Sx,
&dbeta,
(double*)z_data, &Sz);
} }
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, PyErr_SetString(PyExc_AssertionError,
"xx is a double-strided matrix, and should have been " "A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;");
"copied into a memory-contiguous one.");
%(fail)s %(fail)s
} }
} }
else if (dbeta != 1.0)
{
// the matrix has at least one dim of length 0
// so we do this loop, which either iterates over 0 elements
// or else it does the right thing for length-0 A.
dtype_%(z)s * zptr = (dtype_%(z)s*)(PyArray_DATA(%(z)s));
for (int i = 0; i < NA0; ++i)
{
zptr[i * Sz] = (dbeta == 0.0 ? 0.0 : zptr[i * Sz] * dbeta);
}
}
} }
""" """
return code % locals() return code % locals()
...@@ -613,7 +589,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -613,7 +589,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (15, blas_header_version(), check_force_gemv_init()) return (16, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
......
...@@ -2226,8 +2226,10 @@ class TestBlasStrides: ...@@ -2226,8 +2226,10 @@ class TestBlasStrides:
a.set_value(a_dev.copy()[::a_step], borrow=True) a.set_value(a_dev.copy()[::a_step], borrow=True)
b.set_value(b_dev.copy()[::b_step1, ::b_step2], borrow=True) b.set_value(b_dev.copy()[::b_step1, ::b_step2], borrow=True)
# Copy as C so that it becomes F after the transpose in the graph
b_t.set_value( b_t.set_value(
np.transpose(b_dev.copy())[::b_step2, ::b_step1], borrow=True np.transpose(b_dev).copy(order="C")[::b_step2, ::b_step1],
borrow=True,
) )
c.set_value(c_dev.copy()[::c_step], borrow=True) c.set_value(c_dev.copy()[::c_step], borrow=True)
...@@ -2244,6 +2246,7 @@ class TestBlasStrides: ...@@ -2244,6 +2246,7 @@ class TestBlasStrides:
self.cmp_gemv(3, (3, 5), 5, rng) self.cmp_gemv(3, (3, 5), 5, rng)
self.cmp_gemv(1, (1, 5), 5, rng) self.cmp_gemv(1, (1, 5), 5, rng)
self.cmp_gemv(3, (3, 1), 1, rng) self.cmp_gemv(3, (3, 1), 1, rng)
self.cmp_gemv(1, (1, 1), 1, rng)
self.cmp_gemv(0, (0, 5), 5, rng) self.cmp_gemv(0, (0, 5), 5, rng)
self.cmp_gemv(3, (3, 0), 0, rng) self.cmp_gemv(3, (3, 0), 0, rng)
self.cmp_gemv(0, (0, 1), 1, rng) self.cmp_gemv(0, (0, 1), 1, rng)
...@@ -2301,6 +2304,7 @@ class TestBlasStrides: ...@@ -2301,6 +2304,7 @@ class TestBlasStrides:
self.cmp_ger((3, 5), 3, 5, rng) self.cmp_ger((3, 5), 3, 5, rng)
self.cmp_ger((1, 5), 1, 5, rng) self.cmp_ger((1, 5), 1, 5, rng)
self.cmp_ger((3, 1), 3, 1, rng) self.cmp_ger((3, 1), 3, 1, rng)
self.cmp_ger((1, 1), 1, 1, rng)
self.cmp_ger((0, 5), 0, 5, rng) self.cmp_ger((0, 5), 0, 5, rng)
self.cmp_ger((3, 0), 3, 0, rng) self.cmp_ger((3, 0), 3, 0, rng)
self.cmp_ger((0, 1), 0, 1, rng) self.cmp_ger((0, 1), 0, 1, rng)
......
...@@ -243,6 +243,7 @@ class TestCGemv(OptimizationTestMixin): ...@@ -243,6 +243,7 @@ class TestCGemv(OptimizationTestMixin):
self.t_gemv1((0, 2)) self.t_gemv1((0, 2))
self.t_gemv1((3, 1)) self.t_gemv1((3, 1))
self.t_gemv1((3, 0)) self.t_gemv1((3, 0))
self.t_gemv1((1, 1))
self.t_gemv1((1, 0)) self.t_gemv1((1, 0))
self.t_gemv1((0, 1)) self.t_gemv1((0, 1))
self.t_gemv1((0, 0)) self.t_gemv1((0, 0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论