提交 6e06f811 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba symmetrical solve reciprocal of condition number

上级 8a81a53d
...@@ -653,7 +653,7 @@ def solve_gen_impl( ...@@ -653,7 +653,7 @@ def solve_gen_impl(
def _sysv( def _sysv(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
""" """
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve. Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
""" """
...@@ -664,7 +664,8 @@ def _sysv( ...@@ -664,7 +664,8 @@ def _sysv(
def sysv_impl( def sysv_impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[ ) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int] [np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, np.ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "sysv") _check_scipy_linalg_matrix(A, "sysv")
...@@ -740,8 +741,8 @@ def sysv_impl( ...@@ -740,8 +741,8 @@ def sysv_impl(
) )
if B_is_1d: if B_is_1d:
return B_copy[..., 0], IPIV, int_ptr_to_val(INFO) B_copy = B_copy[..., 0]
return B_copy, IPIV, int_ptr_to_val(INFO) return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)
return impl return impl
...@@ -770,7 +771,7 @@ def sycon_impl( ...@@ -770,7 +771,7 @@ def sycon_impl(
N = val_to_int_ptr(_N) N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N) LDA = val_to_int_ptr(_N)
UPLO = val_to_int_ptr(ord("L")) UPLO = val_to_int_ptr(ord("U"))
ANORM = np.array(anorm, dtype=dtype) ANORM = np.array(anorm, dtype=dtype)
RCOND = np.empty(1, dtype=dtype) RCOND = np.empty(1, dtype=dtype)
WORK = np.empty(2 * _N, dtype=dtype) WORK = np.empty(2 * _N, dtype=dtype)
...@@ -843,10 +844,10 @@ def solve_symmetric_impl( ...@@ -843,10 +844,10 @@ def solve_symmetric_impl(
) -> np.ndarray: ) -> np.ndarray:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b) lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
_solve_check(A.shape[-1], info) _solve_check(A.shape[-1], info)
rcond, info = _sycon(A, ipiv, _xlange(A, order="I")) rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))
_solve_check(A.shape[-1], info, True, rcond) _solve_check(A.shape[-1], info, True, rcond)
return x return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论