提交 2741295a authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba cho_solve

上级 1efa92fd
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.core.types import Float from numba.core.types import Complex, Float
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
...@@ -32,20 +32,23 @@ def _cho_solve(C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool): ...@@ -32,20 +32,23 @@ def _cho_solve(C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool):
@overload(_cho_solve) @overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False): def cho_solve_impl(C, B, lower=False, overwrite_b=False):
ensure_lapack() ensure_lapack()
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve") _check_linalg_matrix(C, ndim=2, dtype=(Float, Complex), func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve") _check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="cho_solve")
_check_dtypes_match((C, B), func_name="cho_solve") _check_dtypes_match((C, B), func_name="cho_solve")
dtype = C.dtype dtype = C.dtype
is_complex = isinstance(dtype, Complex)
numba_potrs = _LAPACK().numba_xpotrs(dtype) numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False): def impl(C, B, lower=False, overwrite_b=False):
_solve_check_input_shapes(C, B) _solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1]) _N = np.int32(C.shape[-1])
if C.flags.f_contiguous or C.flags.c_contiguous: if C.flags.f_contiguous:
C_f = C
elif not is_complex and C.flags.c_contiguous:
# For real triangular factors, c_contiguous L is f_contiguous U (and vice versa).
# Not valid for complex where C^T != C^H.
C_f = C C_f = C
if C.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower lower = not lower
else: else:
C_f = np.asfortranarray(C) C_f = np.asfortranarray(C)
......
...@@ -298,20 +298,34 @@ class TestSolves: ...@@ -298,20 +298,34 @@ class TestSolves:
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"], ids=["b_col_vec", "b_matrix", "b_vec"],
) )
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_cho_solve( def test_cho_solve(
self, b_func, b_shape: tuple[int, ...], lower: bool, overwrite_b: bool self,
b_func,
b_shape: tuple[int, ...],
lower: bool,
overwrite_b: bool,
is_complex: bool,
): ):
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
def A_func(x): def A_func(x):
x = x @ x.conj().T x = x @ x.conj().T
x = scipy.linalg.cholesky(x, lower=lower) x = scipy.linalg.cholesky(x, lower=lower)
return x return x
A = pt.matrix("A", dtype=floatX) A = pt.matrix("A", dtype=dtype)
b = pt.tensor("b", shape=b_shape, dtype=floatX) b = pt.tensor("b", shape=b_shape, dtype=dtype)
rng = np.random.default_rng(418) rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX) A_base = rng.normal(size=(5, 5))
b_val = rng.normal(size=b_shape).astype(floatX) if is_complex:
A_base = A_base + 1j * rng.normal(size=(5, 5))
A_val = A_func(A_base).astype(dtype)
b_val = rng.normal(size=b_shape).astype(dtype)
if is_complex:
b_val = b_val + 1j * rng.normal(size=b_shape).astype(dtype)
X = pt.linalg.cho_solve( X = pt.linalg.cho_solve(
(A, lower), (A, lower),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论