Unverified 提交 d8943500 authored 作者: Joren Hammudoglu's avatar Joren Hammudoglu 提交者: GitHub

Add `scipy-stubs` as development depedency (#1598)

* add `scipy-stubs` as development depedency * fix `scipy-stubs` squigglies
上级 f33ea357
...@@ -26,6 +26,7 @@ dependencies: ...@@ -26,6 +26,7 @@ dependencies:
- diff-cover - diff-cover
- mypy - mypy
- types-setuptools - types-setuptools
- scipy-stubs
- pytest - pytest
- pytest-cov - pytest-cov
- pytest-xdist - pytest-xdist
......
...@@ -28,6 +28,7 @@ dependencies: ...@@ -28,6 +28,7 @@ dependencies:
- diff-cover - diff-cover
- mypy - mypy
- types-setuptools - types-setuptools
- scipy-stubs
- pytest - pytest
- pytest-cov - pytest-cov
- pytest-xdist - pytest-xdist
......
from collections.abc import Callable from collections.abc import Callable
from typing import cast as typing_cast from typing import Literal
import numpy as np import numpy as np
from numba import njit as numba_njit from numba import njit as numba_njit
...@@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a): ...@@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def _lu_1( def _lu_1(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: Literal[True],
check_finite: bool, check_finite: bool,
p_indices: bool, p_indices: Literal[False],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
...@@ -48,23 +48,20 @@ def _lu_1( ...@@ -48,23 +48,20 @@ def _lu_1(
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A. array of row swaps, such that L[perm] @ U = A.
""" """
return typing_cast( return linalg.lu(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=check_finite,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
),
) )
def _lu_2( def _lu_2(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: Literal[False],
check_finite: bool, check_finite: bool,
p_indices: bool, p_indices: Literal[True],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
...@@ -73,23 +70,20 @@ def _lu_2( ...@@ -73,23 +70,20 @@ def _lu_2(
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L. permuted L matrix, PL = P @ L.
""" """
return typing_cast( return linalg.lu(
tuple[np.ndarray, np.ndarray],
linalg.lu(
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=check_finite,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
),
) )
def _lu_3( def _lu_3(
a: np.ndarray, a: np.ndarray,
permute_l: bool, permute_l: Literal[False],
check_finite: bool, check_finite: bool,
p_indices: bool, p_indices: Literal[False],
overwrite_a: bool, overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
...@@ -98,15 +92,12 @@ def _lu_3( ...@@ -98,15 +92,12 @@ def _lu_3(
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A. matrix, P @ L @ U = A.
""" """
return typing_cast( return linalg.lu(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a, a,
permute_l=permute_l, permute_l=permute_l,
check_finite=check_finite, check_finite=check_finite,
p_indices=p_indices, p_indices=p_indices,
overwrite_a=overwrite_a, overwrite_a=overwrite_a,
),
) )
......
from collections.abc import Callable from collections.abc import Callable
from typing import cast as typing_cast
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
...@@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: ...@@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information. returns an info code with diagnostic information.
""" """
(getrf,) = linalg.get_lapack_funcs("getrf", (A,)) funcs = linalg.get_lapack_funcs("getrf", (A,))
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a) assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
getrf = funcs[0]
A_copy, ipiv, info = typing_cast(
tuple[np.ndarray, np.ndarray, int], getrf(A, overwrite_a=overwrite_a)
)
return A_copy, ipiv, info return A_copy, ipiv, info
......
from typing import Literal
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
...@@ -13,7 +15,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import ( ...@@ -13,7 +15,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int): def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A.""" """LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
(geqrf,) = get_lapack_funcs(("geqrf",), (A,)) # (geqrf,) = typing_cast(
# list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
# )
funcs = get_lapack_funcs(("geqrf",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqrf = funcs[0]
return geqrf(A, overwrite_a=overwrite_a, lwork=lwork) return geqrf(A, overwrite_a=overwrite_a, lwork=lwork)
...@@ -61,7 +69,10 @@ def xgeqrf_impl(A, overwrite_a, lwork): ...@@ -61,7 +69,10 @@ def xgeqrf_impl(A, overwrite_a, lwork):
def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int): def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A.""" """LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
(geqp3,) = get_lapack_funcs(("geqp3",), (A,)) funcs = get_lapack_funcs(("geqp3",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
geqp3 = funcs[0]
return geqp3(A, overwrite_a=overwrite_a, lwork=lwork) return geqp3(A, overwrite_a=overwrite_a, lwork=lwork)
...@@ -111,7 +122,10 @@ def xgeqp3_impl(A, overwrite_a, lwork): ...@@ -111,7 +122,10 @@ def xgeqp3_impl(A, overwrite_a, lwork):
def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types).""" """LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
(orgqr,) = get_lapack_funcs(("orgqr",), (A,)) funcs = get_lapack_funcs(("orgqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
orgqr = funcs[0]
return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork) return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
...@@ -160,7 +174,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork): ...@@ -160,7 +174,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types).""" """LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
(ungqr,) = get_lapack_funcs(("ungqr",), (A,)) funcs = get_lapack_funcs(("ungqr",), (A,))
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
ungqr = funcs[0]
return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork) return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
...@@ -209,8 +226,8 @@ def xungqr_impl(A, tau, overwrite_a, lwork): ...@@ -209,8 +226,8 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
def _qr_full_pivot( def _qr_full_pivot(
x: np.ndarray, x: np.ndarray,
mode: str = "full", mode: Literal["full", "economic"] = "full",
pivoting: bool = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False, check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
...@@ -234,8 +251,8 @@ def _qr_full_pivot( ...@@ -234,8 +251,8 @@ def _qr_full_pivot(
def _qr_full_no_pivot( def _qr_full_no_pivot(
x: np.ndarray, x: np.ndarray,
mode: str = "full", mode: Literal["full", "economic"] = "full",
pivoting: bool = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False, check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
...@@ -258,8 +275,8 @@ def _qr_full_no_pivot( ...@@ -258,8 +275,8 @@ def _qr_full_no_pivot(
def _qr_r_pivot( def _qr_r_pivot(
x: np.ndarray, x: np.ndarray,
mode: str = "r", mode: Literal["r", "raw"] = "r",
pivoting: bool = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False, check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
...@@ -282,8 +299,8 @@ def _qr_r_pivot( ...@@ -282,8 +299,8 @@ def _qr_r_pivot(
def _qr_r_no_pivot( def _qr_r_no_pivot(
x: np.ndarray, x: np.ndarray,
mode: str = "r", mode: Literal["r", "raw"] = "r",
pivoting: bool = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False, check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
...@@ -306,8 +323,8 @@ def _qr_r_no_pivot( ...@@ -306,8 +323,8 @@ def _qr_r_no_pivot(
def _qr_raw_no_pivot( def _qr_raw_no_pivot(
x: np.ndarray, x: np.ndarray,
mode: str = "raw", mode: Literal["raw"] = "raw",
pivoting: bool = False, pivoting: Literal[False] = False,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False, check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
...@@ -332,8 +349,8 @@ def _qr_raw_no_pivot( ...@@ -332,8 +349,8 @@ def _qr_raw_no_pivot(
def _qr_raw_pivot( def _qr_raw_pivot(
x: np.ndarray, x: np.ndarray,
mode: str = "raw", mode: Literal["raw"] = "raw",
pivoting: bool = True, pivoting: Literal[True] = True,
overwrite_a: bool = False, overwrite_a: bool = False,
check_finite: bool = False, check_finite: bool = False,
lwork: int | None = None, lwork: int | None = None,
......
from collections.abc import Callable from collections.abc import Callable
from typing import Literal, TypeAlias
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
...@@ -20,8 +21,15 @@ from pytensor.link.numba.dispatch.linalg.utils import ( ...@@ -20,8 +21,15 @@ from pytensor.link.numba.dispatch.linalg.utils import (
) )
_Trans: TypeAlias = Literal[0, 1, 2]
def _getrs( def _getrs(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool LU: np.ndarray,
B: np.ndarray,
IPIV: np.ndarray,
trans: _Trans | bool, # mypy does not realize that `bool <: Literal[0, 1]`
overwrite_b: bool,
) -> tuple[np.ndarray, int]: ) -> tuple[np.ndarray, int]:
""" """
Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve. Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
...@@ -31,8 +39,10 @@ def _getrs( ...@@ -31,8 +39,10 @@ def _getrs(
@overload(_getrs) @overload(_getrs)
def getrs_impl( def getrs_impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: _Trans, overwrite_b: bool
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]: ) -> Callable[
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs") _check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs") _check_scipy_linalg_matrix(B, "getrs")
...@@ -41,7 +51,11 @@ def getrs_impl( ...@@ -41,7 +51,11 @@ def getrs_impl(
numba_getrs = _LAPACK().numba_xgetrs(dtype) numba_getrs = _LAPACK().numba_xgetrs(dtype)
def impl( def impl(
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool LU: np.ndarray,
B: np.ndarray,
IPIV: np.ndarray,
trans: _Trans,
overwrite_b: bool,
) -> tuple[np.ndarray, int]: ) -> tuple[np.ndarray, int]:
_N = np.int32(LU.shape[-1]) _N = np.int32(LU.shape[-1])
_solve_check_input_shapes(LU, B) _solve_check_input_shapes(LU, B)
...@@ -89,7 +103,7 @@ def getrs_impl( ...@@ -89,7 +103,7 @@ def getrs_impl(
def _lu_solve( def _lu_solve(
lu_and_piv: tuple[np.ndarray, np.ndarray], lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray, b: np.ndarray,
trans: int, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool, check_finite: bool,
): ):
...@@ -105,10 +119,10 @@ def _lu_solve( ...@@ -105,10 +119,10 @@ def _lu_solve(
def lu_solve_impl( def lu_solve_impl(
lu_and_piv: tuple[np.ndarray, np.ndarray], lu_and_piv: tuple[np.ndarray, np.ndarray],
b: np.ndarray, b: np.ndarray,
trans: int, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool, check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]: ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack() ensure_lapack()
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve") _check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
_check_scipy_linalg_matrix(b, "lu_solve") _check_scipy_linalg_matrix(b, "lu_solve")
...@@ -117,7 +131,7 @@ def lu_solve_impl( ...@@ -117,7 +131,7 @@ def lu_solve_impl(
lu: np.ndarray, lu: np.ndarray,
piv: np.ndarray, piv: np.ndarray,
b: np.ndarray, b: np.ndarray,
trans: int, trans: _Trans,
overwrite_b: bool, overwrite_b: bool,
check_finite: bool, check_finite: bool,
) -> np.ndarray: ) -> np.ndarray:
......
...@@ -6,15 +6,25 @@ import logging ...@@ -6,15 +6,25 @@ import logging
import sys import sys
import warnings import warnings
from math import gcd from math import gcd
from typing import TYPE_CHECKING
import numpy as np import numpy as np
from numpy.exceptions import ComplexWarning from numpy.exceptions import ComplexWarning
try: if TYPE_CHECKING:
# https://github.com/scipy/scipy-stubs/issues/851
from scipy.signal._signaltools import ( # type: ignore[attr-defined]
_bvalfromboundary,
_valfrommode,
convolve,
)
from scipy.signal._sigtools import _convolve2d
else:
try:
from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve
from scipy.signal.sigtools import _convolve2d from scipy.signal.sigtools import _convolve2d
except ImportError: except ImportError:
from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve
from scipy.signal._sigtools import _convolve2d from scipy.signal._sigtools import _convolve2d
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论