提交 2e5e38ad authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid copying C-contiguous arrays in solve methods

上级 0fd8315f
...@@ -126,6 +126,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -126,6 +126,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
dtype = A.dtype dtype = A.dtype
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
if isinstance(dtype, types.Complex):
# If you want to make this work with complex numbers, make sure you handle the c_contiguous trick correctly.
raise TypeError("This function is not expected to work with complex numbers")
def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b): def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
...@@ -135,7 +138,14 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -135,7 +138,14 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim) # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d = B.ndim == 1 B_is_1d = B.ndim == 1
# This will only copy if A is not already fortran contiguous if A.flags.f_contiguous or (A.flags.c_contiguous and trans in (0, 1)):
A_f = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
# Is this valid for complex matrices that were .conj().mT by PyTensor?
lower = not lower
trans = 1 - trans
else:
A_f = np.asfortranarray(A) A_f = np.asfortranarray(A)
if overwrite_b and B.flags.f_contiguous: if overwrite_b and B.flags.f_contiguous:
...@@ -633,6 +643,11 @@ def solve_gen_impl( ...@@ -633,6 +643,11 @@ def solve_gen_impl(
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
if overwrite_a and A.flags.c_contiguous:
# Work with the transposed system to avoid copying A
A = A.T
transposed = not transposed
order = "I" if transposed else "1" order = "I" if transposed else "1"
norm = _xlange(A, order=order) norm = _xlange(A, order=order)
...@@ -682,8 +697,11 @@ def sysv_impl( ...@@ -682,8 +697,11 @@ def sysv_impl(
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore _LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
if overwrite_a and A.flags.f_contiguous: if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else: else:
A_copy = _copy_to_fortran_order(A) A_copy = _copy_to_fortran_order(A)
...@@ -905,8 +923,11 @@ def posv_impl( ...@@ -905,8 +923,11 @@ def posv_impl(
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
if overwrite_a and A.flags.f_contiguous: if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
A_copy = A A_copy = A
if A.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else: else:
A_copy = _copy_to_fortran_order(A) A_copy = _copy_to_fortran_order(A)
...@@ -1128,6 +1149,12 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True): ...@@ -1128,6 +1149,12 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
_solve_check_input_shapes(C, B) _solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1]) _N = np.int32(C.shape[-1])
if C.flags.f_contiguous or C.flags.c_contiguous:
C_f = C
if C.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
else:
C_f = np.asfortranarray(C) C_f = np.asfortranarray(C)
if overwrite_b and B.flags.f_contiguous: if overwrite_b and B.flags.f_contiguous:
......
...@@ -169,7 +169,8 @@ class TestSolves: ...@@ -169,7 +169,8 @@ class TestSolves:
b_val_c_contig = np.copy(b_val, order="C") b_val_c_contig = np.copy(b_val, order="C")
res_c_contig = f(A_val_c_contig, b_val_c_contig) res_c_contig = f(A_val_c_contig, b_val_c_contig)
np.testing.assert_allclose(res_c_contig, res) np.testing.assert_allclose(res_c_contig, res)
np.testing.assert_allclose(A_val_c_contig, A_val) # We can destroy C-contiguous A arrays by inverting `transpose/lower` at runtime
assert np.allclose(A_val_c_contig, A_val) == (not overwrite_a)
# b vectors are always f_contiguous if also c_contiguous # b vectors are always f_contiguous if also c_contiguous
assert np.allclose(b_val_c_contig, b_val) == ( assert np.allclose(b_val_c_contig, b_val) == (
not (overwrite_b and b_val_c_contig.flags.f_contiguous) not (overwrite_b and b_val_c_contig.flags.f_contiguous)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论