提交 8c97bb20 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix Numba pos solve condition number calculation

上级 2e5e38ad
......@@ -884,7 +884,7 @@ def _posv(
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, int]:
) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
......@@ -901,7 +901,8 @@ def posv_impl(
check_finite: bool,
transposed: bool,
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
......@@ -918,7 +919,7 @@ def posv_impl(
overwrite_b: bool,
check_finite: bool,
transposed: bool,
) -> tuple[np.ndarray, int]:
) -> tuple[np.ndarray, np.ndarray, int]:
_solve_check_input_shapes(A, B)
_N = np.int32(A.shape[-1])
......@@ -962,8 +963,9 @@ def posv_impl(
)
if B_is_1d:
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
B_copy = B_copy[..., 0]
return A_copy, B_copy, int_ptr_to_val(INFO)
return impl
......@@ -1064,10 +1066,12 @@ def solve_psd_impl(
) -> np.ndarray:
_solve_check_input_shapes(A, B)
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
C, x, info = _posv(
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
)
_solve_check(A.shape[-1], info)
rcond, info = _pocon(x, _xlange(A))
rcond, info = _pocon(C, _xlange(A))
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论