提交 e75bbb2c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove unnecessary overloads

Inline when only used in one place, or remove if altogether unused
上级 672a4829
......@@ -4,225 +4,16 @@ import numpy as np
from numba.core.extending import overload
from numba.core.types import Complex, Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy.linalg import get_lapack_funcs, qr
from scipy.linalg import qr
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
# (geqrf,) = typing_cast(
# list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
# )
funcs = get_lapack_funcs(("geqrf",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqrf = funcs[0]
return geqrf(A, overwrite_a=overwrite_a, lwork=lwork)
@overload(_xgeqrf)
def xgeqrf_impl(A, overwrite_a, lwork):
ensure_lapack()
dtype = A.dtype
geqrf = _LAPACK().numba_xgeqrf(dtype)
def impl(A, overwrite_a, lwork):
M = np.int32(A.shape[0])
N = np.int32(A.shape[1])
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
LDA = val_to_int_ptr(M)
TAU = np.empty(min(M, N), dtype=dtype)
if lwork == -1:
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
else:
WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
LWORK = val_to_int_ptr(WORK.size)
INFO = val_to_int_ptr(1)
geqrf(
val_to_int_ptr(M),
val_to_int_ptr(N),
A_copy.ctypes,
LDA,
TAU.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
return A_copy, TAU, WORK, int_ptr_to_val(INFO)
return impl
def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
funcs = get_lapack_funcs(("geqp3",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqp3 = funcs[0]
return geqp3(A, overwrite_a=overwrite_a, lwork=lwork)
@overload(_xgeqp3)
def xgeqp3_impl(A, overwrite_a, lwork):
ensure_lapack()
dtype = A.dtype
geqp3 = _LAPACK().numba_xgeqp3(dtype)
def impl(A, overwrite_a, lwork):
M = np.int32(A.shape[0])
N = np.int32(A.shape[1])
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
LDA = val_to_int_ptr(M)
JPVT = np.zeros(N, dtype=np.int32)
TAU = np.empty(min(M, N), dtype=dtype)
if lwork == -1:
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
else:
WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
LWORK = val_to_int_ptr(WORK.size)
INFO = val_to_int_ptr(1)
geqp3(
val_to_int_ptr(M),
val_to_int_ptr(N),
A_copy.ctypes,
LDA,
JPVT.ctypes,
TAU.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
return A_copy, JPVT, TAU, WORK, int_ptr_to_val(INFO)
return impl
def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
funcs = get_lapack_funcs(("orgqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
orgqr = funcs[0]
return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
@overload(_xorgqr)
def xorgqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack()
dtype = A.dtype
orgqr = _LAPACK().numba_xorgqr(dtype)
def impl(A, tau, overwrite_a, lwork):
M = np.int32(A.shape[0])
N = np.int32(A.shape[1])
K = np.int32(tau.shape[0])
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
if lwork == -1:
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
else:
WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
LWORK = val_to_int_ptr(WORK.size)
LDA = val_to_int_ptr(M)
INFO = val_to_int_ptr(1)
orgqr(
val_to_int_ptr(M),
val_to_int_ptr(N),
val_to_int_ptr(K),
A_copy.ctypes,
LDA,
tau.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
return A_copy, WORK, int_ptr_to_val(INFO)
return impl
def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
funcs = get_lapack_funcs(("ungqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
ungqr = funcs[0]
return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
@overload(_xungqr)
def xungqr_impl(A, tau, overwrite_a, lwork):
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="qr")
dtype = A.dtype
ungqr = _LAPACK().numba_xungqr(dtype)
def impl(A, tau, overwrite_a, lwork):
M = np.int32(A.shape[0])
N = np.int32(A.shape[1])
K = np.int32(tau.shape[0])
if overwrite_a and A.flags.f_contiguous:
A_copy = A
else:
A_copy = _copy_to_fortran_order(A)
LDA = val_to_int_ptr(M)
if lwork == -1:
WORK = np.empty(1, dtype=dtype)
LWORK = val_to_int_ptr(-1)
else:
WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
LWORK = val_to_int_ptr(WORK.size)
INFO = val_to_int_ptr(1)
ungqr(
val_to_int_ptr(M),
val_to_int_ptr(N),
val_to_int_ptr(K),
A_copy.ctypes,
LDA,
tau.ctypes,
WORK.ctypes,
LWORK,
INFO,
)
return A_copy, WORK, int_ptr_to_val(INFO)
return impl
def _qr_full_pivot(
x: np.ndarray,
mode: Literal["full", "economic"] = "full",
......
......@@ -19,36 +19,42 @@ from pytensor.link.numba.dispatch.linalg.utils import (
)
def _posv(
def _solve_psd(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
return # type: ignore
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
transposed=transposed,
assume_a="pos",
)
@overload(_posv)
def posv_impl(
@overload(_solve_psd)
def solve_psd_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, int],
]:
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
dtype = A.dtype
numba_posv = _LAPACK().numba_xposv(dtype)
numba_posv = _LAPACK().numba_xposv(A.dtype)
def impl(
A: np.ndarray,
......@@ -56,9 +62,9 @@ def posv_impl(
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
) -> tuple[np.ndarray, np.ndarray, int]:
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
_N = np.int32(A.shape[-1])
if overwrite_a and (A.flags.f_contiguous or A.flags.c_contiguous):
......@@ -102,62 +108,9 @@ def posv_impl(
if B_is_1d:
B_copy = B_copy[..., 0]
return A_copy, B_copy, int_ptr_to_val(INFO)
return impl
def _solve_psd(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
transposed=transposed,
assume_a="pos",
)
@overload(_solve_psd)
def solve_psd_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
_C, x, info = _posv(A, B, lower, overwrite_a, overwrite_b)
if info != 0:
x = np.full_like(x, np.nan)
if int_ptr_to_val(INFO) != 0:
B_copy = np.full_like(B_copy, np.nan)
return x
return B_copy
return impl
......@@ -19,32 +19,52 @@ from pytensor.link.numba.dispatch.linalg.utils import (
)
def _sysv(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
return # type: ignore
@overload(_sysv)
def sysv_impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, bool, bool, bool],
tuple[np.ndarray, np.ndarray, np.ndarray, int],
]:
def _solve_symmetric(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
assume_a="sym",
transposed=transposed,
)
@overload(_solve_symmetric)
def solve_symmetric_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="sysv")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="sysv")
_check_dtypes_match((A, B), func_name="sysv")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
dtype = A.dtype
numba_sysv = _LAPACK().numba_xsysv(dtype)
numba_sysv = _LAPACK().numba_xsysv(A.dtype)
def impl(
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
):
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> np.ndarray:
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
_solve_check_input_shapes(A, B)
......@@ -112,64 +132,12 @@ def sysv_impl(
INFO,
)
if int_ptr_to_val(INFO) != 0:
B_copy = np.full_like(B_copy, np.nan)
if B_is_1d:
B_copy = B_copy[..., 0]
return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)
return impl
def _solve_symmetric(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return linalg.solve(
A,
B,
lower=lower,
overwrite_a=overwrite_a,
overwrite_b=overwrite_b,
check_finite=False,
assume_a="sym",
transposed=transposed,
)
@overload(_solve_symmetric)
def solve_symmetric_impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), func_name="solve")
def impl(
A: np.ndarray,
B: np.ndarray,
lower: bool,
overwrite_a: bool,
overwrite_b: bool,
transposed: bool,
) -> np.ndarray:
_solve_check_input_shapes(A, B)
_lu, x, _ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
if info != 0:
x = np.full_like(x, np.nan)
return x
return B_copy
return impl
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论