提交 afe934b2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid copy of flipped A matrices in GEMV

上级 b2365e0e
......@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
// gemv expects pointers to the beginning of memory arrays,
......@@ -435,17 +436,25 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
if (NA0 * NA1)
{
// If A is neither C- nor F-contiguous, we make a copy.
// TODO:
// - if one stride is equal to "- elemsize", we can still call
// gemv on reversed matrix and vectors
// - if the copy is too long, maybe call vector/vector dot on
// each row instead
if ((PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
&& (PyArray_STRIDES(%(A)s)[1] != elemsize)))
if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
{
// We can treat the array A as C-or F-contiguous by changing the order of iteration
if (SA0 < 0){
A_data += (NA0 -1) * SA0; // Jump to first row
SA0 = -SA0; // Iterate over rows in reverse
Sz = -Sz; // Iterate over y in reverse
}
if (SA1 < 0){
A_data += (NA1 -1) * SA1; // Jump to first column
SA1 = -SA1; // Iterate over columns in reverse
Sx = -Sx; // Iterate over x in reverse
}
} else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
{
// Array isn't contiguous, we have to make a copy
// - if the copy is too long, maybe call vector/vector dot on
// each row instead
// printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1);
npy_intp dims[2];
dims[0] = NA0;
dims[1] = NA1;
......@@ -458,16 +467,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
%(A)s = A_copy;
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
}
if (PyArray_STRIDES(%(A)s)[0] == elemsize)
if (SA0 == 1)
{
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)(A_data), &SA1,
(float*)x_data, &Sx,
&fbeta,
(float*)z_data, &Sz);
......@@ -477,7 +487,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)(A_data), &SA1,
(double*)x_data, &Sx,
&dbeta,
(double*)z_data, &Sz);
......@@ -489,7 +499,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
%(fail)s
}
}
else if (PyArray_STRIDES(%(A)s)[1] == elemsize)
else if (SA1 == 1)
{
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{
......@@ -506,14 +516,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
z_data[0] = 0.f;
}
z_data[0] += alpha*sdot_(&NA1,
(float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)(A_data), &SA1,
(float*)x_data, &Sx);
}
else
{
sgemv_(&TRANS, &NA1, &NA0,
&alpha,
(float*)(PyArray_DATA(%(A)s)), &SA0,
(float*)(A_data), &SA0,
(float*)x_data, &Sx,
&fbeta,
(float*)z_data, &Sz);
......@@ -534,14 +544,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
z_data[0] = 0.;
}
z_data[0] += alpha*ddot_(&NA1,
(double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)(A_data), &SA1,
(double*)x_data, &Sx);
}
else
{
dgemv_(&TRANS, &NA1, &NA0,
&alpha,
(double*)(PyArray_DATA(%(A)s)), &SA0,
(double*)(A_data), &SA0,
(double*)x_data, &Sx,
&dbeta,
(double*)z_data, &Sz);
......@@ -603,7 +613,7 @@ class CGemv(BaseBLAS, Gemv):
return code
def c_code_cache_version(self):
return (14, blas_header_version(), check_force_gemv_init())
return (15, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True)
......
......@@ -411,3 +411,45 @@ class TestSdotNoFlags(TestCGemvNoFlags):
class TestBlasStridesC(TestBlasStrides):
mode = mode_blas_opt
@pytest.mark.parametrize(
"neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"]
)
@pytest.mark.parametrize(
"neg_stride0", (True, False), ids=["neg_stride0", "pos_stride0"]
)
@pytest.mark.parametrize("F_layout", (True, False), ids=["F_layout", "C_layout"])
def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmark):
A = pt.matrix("A", shape=(512, 512))
x = pt.vector("x", shape=(A.type.shape[-1],))
y = pt.vector("y", shape=(A.type.shape[0],))
out = CGemv(inplace=False)(
y,
1.0,
A,
x,
1.0,
)
fn = pytensor.function([A, x, y], out, trust_input=True)
rng = np.random.default_rng(430)
test_A = rng.normal(size=A.type.shape)
test_x = rng.normal(size=x.type.shape)
test_y = rng.normal(size=y.type.shape)
if F_layout:
test_A = test_A.T
if neg_stride0:
test_A = test_A[::-1]
if neg_stride1:
test_A = test_A[:, ::-1]
assert (test_A.strides[0] < 0) == neg_stride0
assert (test_A.strides[1] < 0) == neg_stride1
# Check result is correct by using a copy of A with positive strides
res = fn(test_A, test_x, test_y)
np.testing.assert_allclose(res, fn(test_A.copy(), test_x, test_y))
benchmark(fn, test_A, test_x, test_y)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论