提交 8c97bb20 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix Numba pos solve condition number calculation

上级 2e5e38ad
...@@ -884,7 +884,7 @@ def _posv( ...@@ -884,7 +884,7 @@ def _posv(
overwrite_b: bool, overwrite_b: bool,
check_finite: bool, check_finite: bool,
transposed: bool, transposed: bool,
) -> tuple[np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, int]:
""" """
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve. Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
""" """
...@@ -901,7 +901,8 @@ def posv_impl( ...@@ -901,7 +901,8 @@ def posv_impl(
check_finite: bool, check_finite: bool,
transposed: bool, transposed: bool,
) -> Callable[ ) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int] [np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int],
]: ]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(A, "solve") _check_scipy_linalg_matrix(A, "solve")
...@@ -918,7 +919,7 @@ def posv_impl( ...@@ -918,7 +919,7 @@ def posv_impl(
overwrite_b: bool, overwrite_b: bool,
check_finite: bool, check_finite: bool,
transposed: bool, transposed: bool,
) -> tuple[np.ndarray, int]: ) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
_N = np.int32(A.shape[-1]) _N = np.int32(A.shape[-1])
...@@ -962,8 +963,9 @@ def posv_impl( ...@@ -962,8 +963,9 @@ def posv_impl(
) )
if B_is_1d: if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO) B_copy = B_copy[..., 0]
return B_copy, int_ptr_to_val(INFO)
return A_copy, B_copy, int_ptr_to_val(INFO)
return impl return impl
...@@ -1064,10 +1066,12 @@ def solve_psd_impl( ...@@ -1064,10 +1066,12 @@ def solve_psd_impl(
) -> np.ndarray: ) -> np.ndarray:
_solve_check_input_shapes(A, B) _solve_check_input_shapes(A, B)
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed) C, x, info = _posv(
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info) _solve_check(A.shape[-1], info)
rcond, info = _pocon(x, _xlange(A)) rcond, info = _pocon(C, _xlange(A))
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond) _solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
return x return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论