提交 6e06f811 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba symmetrical solve reciprocal of condition number

上级 8a81a53d
......@@ -653,7 +653,7 @@ def solve_gen_impl(
def _sysv(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, int]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
......@@ -664,7 +664,8 @@ def _sysv(
def sysv_impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int]
[np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv")
......@@ -740,8 +741,8 @@ def sysv_impl(
)
if B_is_1d:
return B_copy[..., 0], IPIV, int_ptr_to_val(INFO)
return B_copy, IPIV, int_ptr_to_val(INFO)
B_copy = B_copy[..., 0]
return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)
return impl
......@@ -770,7 +771,7 @@ def sycon_impl(
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
UPLO = val_to_int_ptr(ord("L"))
UPLO = val_to_int_ptr(ord("U"))
ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(2 * _N, dtype=dtype)
......@@ -843,10 +844,10 @@ def solve_symmetric_impl(
) -> np.ndarray:
_solve_check_input_shapes(A, B)
x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info)
rcond, info = _sycon(A, ipiv, _xlange(A, order="I"))
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))
_solve_check(A.shape[-1], info, True, rcond)
return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论