提交 afe934b2 作者: Ricardo Vieira 提交者: Ricardo Vieira

Avoid copy of flipped A matrices in GEMV

上级 b2365e0e
...@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize; int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize; int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s); dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s); dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
// gemv expects pointers to the beginning of memory arrays, // gemv expects pointers to the beginning of memory arrays,
...@@ -435,17 +436,25 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -435,17 +436,25 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
if (NA0 * NA1) if (NA0 * NA1)
{ {
// If A is neither C- nor F-contiguous, we make a copy. if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
// TODO:
// - if one stride is equal to "- elemsize", we can still call
// gemv on reversed matrix and vectors
// - if the copy is too long, maybe call vector/vector dot on
// each row instead
if ((PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
&& (PyArray_STRIDES(%(A)s)[1] != elemsize)))
{ {
// We can treat the array A as C-or F-contiguous by changing the order of iteration
if (SA0 < 0){
A_data += (NA0 -1) * SA0; // Jump to first row
SA0 = -SA0; // Iterate over rows in reverse
Sz = -Sz; // Iterate over y in reverse
}
if (SA1 < 0){
A_data += (NA1 -1) * SA1; // Jump to first column
SA1 = -SA1; // Iterate over columns in reverse
Sx = -Sx; // Iterate over x in reverse
}
} else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
{
// Array isn't contiguous, we have to make a copy
// - if the copy is too long, maybe call vector/vector dot on
// each row instead
// printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1);
npy_intp dims[2]; npy_intp dims[2];
dims[0] = NA0; dims[0] = NA0;
dims[1] = NA1; dims[1] = NA1;
...@@ -458,16 +467,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -458,16 +467,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
%(A)s = A_copy; %(A)s = A_copy;
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1); SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1); SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
} }
if (PyArray_STRIDES(%(A)s)[0] == elemsize) if (SA0 == 1)
{ {
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{ {
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&NOTRANS, &NA0, &NA1, sgemv_(&NOTRANS, &NA0, &NA1,
&alpha, &alpha,
(float*)(PyArray_DATA(%(A)s)), &SA1, (float*)(A_data), &SA1,
(float*)x_data, &Sx, (float*)x_data, &Sx,
&fbeta, &fbeta,
(float*)z_data, &Sz); (float*)z_data, &Sz);
...@@ -477,7 +487,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -477,7 +487,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &NA0, &NA1, dgemv_(&NOTRANS, &NA0, &NA1,
&alpha, &alpha,
(double*)(PyArray_DATA(%(A)s)), &SA1, (double*)(A_data), &SA1,
(double*)x_data, &Sx, (double*)x_data, &Sx,
&dbeta, &dbeta,
(double*)z_data, &Sz); (double*)z_data, &Sz);
...@@ -489,7 +499,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -489,7 +499,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
%(fail)s %(fail)s
} }
} }
else if (PyArray_STRIDES(%(A)s)[1] == elemsize) else if (SA1 == 1)
{ {
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{ {
...@@ -506,14 +516,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -506,14 +516,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
z_data[0] = 0.f; z_data[0] = 0.f;
} }
z_data[0] += alpha*sdot_(&NA1, z_data[0] += alpha*sdot_(&NA1,
(float*)(PyArray_DATA(%(A)s)), &SA1, (float*)(A_data), &SA1,
(float*)x_data, &Sx); (float*)x_data, &Sx);
} }
else else
{ {
sgemv_(&TRANS, &NA1, &NA0, sgemv_(&TRANS, &NA1, &NA0,
&alpha, &alpha,
(float*)(PyArray_DATA(%(A)s)), &SA0, (float*)(A_data), &SA0,
(float*)x_data, &Sx, (float*)x_data, &Sx,
&fbeta, &fbeta,
(float*)z_data, &Sz); (float*)z_data, &Sz);
...@@ -534,14 +544,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -534,14 +544,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
z_data[0] = 0.; z_data[0] = 0.;
} }
z_data[0] += alpha*ddot_(&NA1, z_data[0] += alpha*ddot_(&NA1,
(double*)(PyArray_DATA(%(A)s)), &SA1, (double*)(A_data), &SA1,
(double*)x_data, &Sx); (double*)x_data, &Sx);
} }
else else
{ {
dgemv_(&TRANS, &NA1, &NA0, dgemv_(&TRANS, &NA1, &NA0,
&alpha, &alpha,
(double*)(PyArray_DATA(%(A)s)), &SA0, (double*)(A_data), &SA0,
(double*)x_data, &Sx, (double*)x_data, &Sx,
&dbeta, &dbeta,
(double*)z_data, &Sz); (double*)z_data, &Sz);
...@@ -603,7 +613,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -603,7 +613,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (14, blas_header_version(), check_force_gemv_init()) return (15, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
......
...@@ -411,3 +411,45 @@ class TestSdotNoFlags(TestCGemvNoFlags): ...@@ -411,3 +411,45 @@ class TestSdotNoFlags(TestCGemvNoFlags):
class TestBlasStridesC(TestBlasStrides): class TestBlasStridesC(TestBlasStrides):
mode = mode_blas_opt mode = mode_blas_opt
@pytest.mark.parametrize(
    "neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"]
)
@pytest.mark.parametrize(
    "neg_stride0", (True, False), ids=["neg_stride0", "pos_stride0"]
)
@pytest.mark.parametrize("F_layout", (True, False), ids=["F_layout", "C_layout"])
def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmark):
    """Check and benchmark ``CGemv`` on A matrices with reversed axes.

    A 512x512 matrix is optionally transposed (``F_layout``) and then
    reversed along axis 0 (``neg_stride0``) and/or axis 1 (``neg_stride1``),
    so that the corresponding strides are negative. The compiled function's
    result on that view is compared against its result on a fresh copy of
    the same values, and the call on the view is then benchmarked.
    """
    # Symbolic inputs; the vector lengths are taken from A's static shape.
    mat = pt.matrix("A", shape=(512, 512))
    n_rows, n_cols = mat.type.shape[0], mat.type.shape[-1]
    vec_x = pt.vector("x", shape=(n_cols,))
    vec_y = pt.vector("y", shape=(n_rows,))

    gemv_op = CGemv(inplace=False)
    # The two 1.0 scalars presumably are alpha and beta -- confirm against
    # the Gemv signature.
    out = gemv_op(vec_y, 1.0, mat, vec_x, 1.0)
    fn = pytensor.function([mat, vec_x, vec_y], out, trust_input=True)

    # Draw order (A, then x, then y) is kept so the seeded values match.
    rng = np.random.default_rng(430)
    test_A, test_x, test_y = [
        rng.normal(size=var.type.shape) for var in (mat, vec_x, vec_y)
    ]

    # Build the requested memory layout purely through views.
    test_A = test_A.T if F_layout else test_A
    test_A = test_A[::-1] if neg_stride0 else test_A
    test_A = test_A[:, ::-1] if neg_stride1 else test_A

    # Sanity check: each axis has a negative stride exactly when requested.
    for axis_stride, want_negative in zip(
        test_A.strides, (neg_stride0, neg_stride1)
    ):
        assert (axis_stride < 0) == want_negative

    # Check result is correct by using a copy of A with positive strides
    res = fn(test_A, test_x, test_y)
    reference = fn(test_A.copy(), test_x, test_y)
    np.testing.assert_allclose(res, reference)

    benchmark(fn, test_A, test_x, test_y)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论