提交 2741295a authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

allow complex inputs to numba cho_solve

上级 1efa92fd
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.core.types import Complex, Float
from numba.np.linalg import ensure_lapack
from scipy import linalg
......@@ -32,21 +32,24 @@ def _cho_solve(C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool):
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False):
ensure_lapack()
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
_check_linalg_matrix(C, ndim=2, dtype=(Float, Complex), func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=(Float, Complex), func_name="cho_solve")
_check_dtypes_match((C, B), func_name="cho_solve")
dtype = C.dtype
is_complex = isinstance(dtype, Complex)
numba_potrs = _LAPACK().numba_xpotrs(dtype)
def impl(C, B, lower=False, overwrite_b=False):
_solve_check_input_shapes(C, B)
_N = np.int32(C.shape[-1])
if C.flags.f_contiguous or C.flags.c_contiguous:
if C.flags.f_contiguous:
C_f = C
if C.flags.c_contiguous:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower = not lower
elif not is_complex and C.flags.c_contiguous:
# For real triangular factors, c_contiguous L is f_contiguous U (and vice versa).
# Not valid for complex where C^T != C^H.
C_f = C
lower = not lower
else:
C_f = np.asfortranarray(C)
......
......@@ -298,20 +298,34 @@ class TestSolves:
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_cho_solve(
self, b_func, b_shape: tuple[int, ...], lower: bool, overwrite_b: bool
self,
b_func,
b_shape: tuple[int, ...],
lower: bool,
overwrite_b: bool,
is_complex: bool,
):
complex_dtype = "complex64" if floatX.endswith("32") else "complex128"
dtype = complex_dtype if is_complex else floatX
def A_func(x):
x = x @ x.conj().T
x = scipy.linalg.cholesky(x, lower=lower)
return x
A = pt.matrix("A", dtype=floatX)
b = pt.tensor("b", shape=b_shape, dtype=floatX)
A = pt.matrix("A", dtype=dtype)
b = pt.tensor("b", shape=b_shape, dtype=dtype)
rng = np.random.default_rng(418)
A_val = A_func(rng.normal(size=(5, 5))).astype(floatX)
b_val = rng.normal(size=b_shape).astype(floatX)
A_base = rng.normal(size=(5, 5))
if is_complex:
A_base = A_base + 1j * rng.normal(size=(5, 5))
A_val = A_func(A_base).astype(dtype)
b_val = rng.normal(size=b_shape).astype(dtype)
if is_complex:
b_val = b_val + 1j * rng.normal(size=b_shape).astype(dtype)
X = pt.linalg.cho_solve(
(A, lower),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论