Unverified 提交 d8943500 authored 作者: Joren Hammudoglu's avatar Joren Hammudoglu 提交者: GitHub

Add `scipy-stubs` as development depedency (#1598)

* add `scipy-stubs` as development depedency * fix `scipy-stubs` squigglies
上级 f33ea357
......@@ -26,6 +26,7 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
- scipy-stubs
- pytest
- pytest-cov
- pytest-xdist
......
......@@ -28,6 +28,7 @@ dependencies:
- diff-cover
- mypy
- types-setuptools
- scipy-stubs
- pytest
- pytest-cov
- pytest-xdist
......
from collections.abc import Callable
from typing import cast as typing_cast
from typing import Literal
import numpy as np
from numba import njit as numba_njit
......@@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def _lu_1(
a: np.ndarray,
permute_l: bool,
permute_l: Literal[True],
check_finite: bool,
p_indices: bool,
p_indices: Literal[False],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
......@@ -48,23 +48,20 @@ def _lu_1(
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
return linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
def _lu_2(
a: np.ndarray,
permute_l: bool,
permute_l: Literal[False],
check_finite: bool,
p_indices: bool,
p_indices: Literal[True],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]:
"""
......@@ -73,23 +70,20 @@ def _lu_2(
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray],
linalg.lu(
return linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
def _lu_3(
a: np.ndarray,
permute_l: bool,
permute_l: Literal[False],
check_finite: bool,
p_indices: bool,
p_indices: Literal[False],
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
......@@ -98,15 +92,12 @@ def _lu_3(
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
return linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)
......
from collections.abc import Callable
from typing import cast as typing_cast
import numpy as np
from numba.core.extending import overload
......@@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information.
"""
(getrf,) = linalg.get_lapack_funcs("getrf", (A,))
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
funcs = linalg.get_lapack_funcs("getrf", (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
getrf = funcs[0]
A_copy, ipiv, info = typing_cast(
tuple[np.ndarray, np.ndarray, int], getrf(A, overwrite_a=overwrite_a)
)
return A_copy, ipiv, info
......
from typing import Literal
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
......@@ -13,7 +15,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
(geqrf,) = get_lapack_funcs(("geqrf",), (A,))
# (geqrf,) = typing_cast(
# list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
# )
funcs = get_lapack_funcs(("geqrf",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqrf = funcs[0]
return geqrf(A, overwrite_a=overwrite_a, lwork=lwork)
......@@ -61,7 +69,10 @@ def xgeqrf_impl(A, overwrite_a, lwork):
def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
(geqp3,) = get_lapack_funcs(("geqp3",), (A,))
funcs = get_lapack_funcs(("geqp3",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqp3 = funcs[0]
return geqp3(A, overwrite_a=overwrite_a, lwork=lwork)
......@@ -111,7 +122,10 @@ def xgeqp3_impl(A, overwrite_a, lwork):
def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
(orgqr,) = get_lapack_funcs(("orgqr",), (A,))
funcs = get_lapack_funcs(("orgqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
orgqr = funcs[0]
return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
......@@ -160,7 +174,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
(ungqr,) = get_lapack_funcs(("ungqr",), (A,))
funcs = get_lapack_funcs(("ungqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
ungqr = funcs[0]
return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
......@@ -209,8 +226,8 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
def _qr_full_pivot(
x: np.ndarray,
mode: str = "full",
pivoting: bool = True,
mode: Literal["full", "economic"] = "full",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
......@@ -234,8 +251,8 @@ def _qr_full_pivot(
def _qr_full_no_pivot(
x: np.ndarray,
mode: str = "full",
pivoting: bool = False,
mode: Literal["full", "economic"] = "full",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
......@@ -258,8 +275,8 @@ def _qr_full_no_pivot(
def _qr_r_pivot(
x: np.ndarray,
mode: str = "r",
pivoting: bool = True,
mode: Literal["r", "raw"] = "r",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
......@@ -282,8 +299,8 @@ def _qr_r_pivot(
def _qr_r_no_pivot(
x: np.ndarray,
mode: str = "r",
pivoting: bool = False,
mode: Literal["r", "raw"] = "r",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
......@@ -306,8 +323,8 @@ def _qr_r_no_pivot(
def _qr_raw_no_pivot(
x: np.ndarray,
mode: str = "raw",
pivoting: bool = False,
mode: Literal["raw"] = "raw",
pivoting: Literal[False] = False,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
......@@ -332,8 +349,8 @@ def _qr_raw_no_pivot(
def _qr_raw_pivot(
x: np.ndarray,
mode: str = "raw",
pivoting: bool = True,
mode: Literal["raw"] = "raw",
pivoting: Literal[True] = True,
overwrite_a: bool = False,
check_finite: bool = False,
lwork: int | None = None,
......
from collections.abc import Callable
from typing import Literal, TypeAlias
import numpy as np
from numba.core.extending import overload
......@@ -20,8 +21,15 @@ from pytensor.link.numba.dispatch.linalg.utils import (
)
_Trans: TypeAlias = Literal[0, 1, 2]
def _getrs(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
LU: np.ndarray,
B: np.ndarray,
IPIV: np.ndarray,
trans: _Trans | bool, # mypy does not realize that `bool <: Literal[0, 1]`
overwrite_b: bool,
) -> tuple[np.ndarray, int]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
......@@ -31,8 +39,10 @@ def _getrs(
@overload(_getrs)
def getrs_impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: _Trans, overwrite_b: bool
) -> Callable[
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]:
ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs")
......@@ -41,7 +51,11 @@ def getrs_impl(
numba_getrs = _LAPACK().numba_xgetrs(dtype)
def impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
LU: np.ndarray,
B: np.ndarray,
IPIV: np.ndarray,
trans: _Trans,
overwrite_b: bool,
) -> tuple[np.ndarray, int]:
_N = np.int32(LU.shape[-1])
_solve_check_input_shapes(LU, B)
......@@ -89,7 +103,7 @@ def getrs_impl(
def _lu_solve(
lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray,
trans: int,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
):
......@@ -105,10 +119,10 @@ def _lu_solve(
def lu_solve_impl(
lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray,
trans: int,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]:
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
_check_scipy_linalg_matrix(b, "lu_solve")
......@@ -117,7 +131,7 @@ def lu_solve_impl(
lu: np.ndarray,
piv: np.ndarray,
b: np.ndarray,
trans: int,
trans: _Trans,
overwrite_b: bool,
check_finite: bool,
) -> np.ndarray:
......
......@@ -6,15 +6,25 @@ import logging
import sys
import warnings
from math import gcd
from typing import TYPE_CHECKING
import numpy as np
from numpy.exceptions import ComplexWarning
try:
if TYPE_CHECKING:
# https://github.com/scipy/scipy-stubs/issues/851
from scipy.signal._signaltools import ( # type: ignore[attr-defined]
_bvalfromboundary,
_valfrommode,
convolve,
)
from scipy.signal._sigtools import _convolve2d
else:
try:
from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve
from scipy.signal.sigtools import _convolve2d
except ImportError:
except ImportError:
from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve
from scipy.signal._sigtools import _convolve2d
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论