提交 6cf729b9 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Fix mypy in tensor/slinalg.py

上级 7d54c5e4
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import typing import typing
import warnings import warnings
from functools import reduce from functools import reduce
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal, cast
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg
...@@ -141,7 +141,7 @@ def cholesky(x, lower=True, on_error="raise", check_finite=False): ...@@ -141,7 +141,7 @@ def cholesky(x, lower=True, on_error="raise", check_finite=False):
class SolveBase(Op): class SolveBase(Op):
"""Base class for `scipy.linalg` matrix equation solvers.""" """Base class for `scipy.linalg` matrix equation solvers."""
__props__ = ( __props__: tuple[str, ...] = (
"lower", "lower",
"check_finite", "check_finite",
"b_ndim", "b_ndim",
...@@ -352,7 +352,7 @@ def solve_triangular( ...@@ -352,7 +352,7 @@ def solve_triangular(
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
""" """
b_ndim = _default_b_ndim(b, b_ndim) b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise( ret = Blockwise(
SolveTriangular( SolveTriangular(
lower=lower, lower=lower,
trans=trans, trans=trans,
...@@ -361,6 +361,7 @@ def solve_triangular( ...@@ -361,6 +361,7 @@ def solve_triangular(
b_ndim=b_ndim, b_ndim=b_ndim,
) )
)(a, b) )(a, b)
return cast(TensorVariable, ret)
class Solve(SolveBase): class Solve(SolveBase):
...@@ -714,9 +715,7 @@ class BilinearSolveDiscreteLyapunov(Op): ...@@ -714,9 +715,7 @@ class BilinearSolveDiscreteLyapunov(Op):
_solve_continuous_lyapunov = SolveContinuousLyapunov() _solve_continuous_lyapunov = SolveContinuousLyapunov()
_solve_bilinear_direct_lyapunov = typing.cast( _solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
typing.Callable, BilinearSolveDiscreteLyapunov()
)
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable: def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
...@@ -729,7 +728,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV ...@@ -729,7 +728,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
AA = kron(A_, A_) AA = kron(A_, A_)
X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel()) X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
return typing.cast(TensorVariable, reshape(X, Q_.shape)) return cast(TensorVariable, reshape(X, Q_.shape))
def solve_discrete_lyapunov( def solve_discrete_lyapunov(
...@@ -765,7 +764,7 @@ def solve_discrete_lyapunov( ...@@ -765,7 +764,7 @@ def solve_discrete_lyapunov(
if method == "direct": if method == "direct":
return _direct_solve_discrete_lyapunov(A, Q) return _direct_solve_discrete_lyapunov(A, Q)
if method == "bilinear": if method == "bilinear":
return typing.cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q)) return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable: def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
...@@ -785,7 +784,7 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl ...@@ -785,7 +784,7 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
""" """
return typing.cast(TensorVariable, _solve_continuous_lyapunov(A, Q)) return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
class SolveDiscreteARE(pt.Op): class SolveDiscreteARE(pt.Op):
...@@ -866,9 +865,7 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable: ...@@ -866,9 +865,7 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
Square matrix of shape M x M, representing the solution to the DARE Square matrix of shape M x M, representing the solution to the DARE
""" """
return typing.cast( return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R)
)
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
......
...@@ -24,7 +24,6 @@ pytensor/tensor/random/basic.py ...@@ -24,7 +24,6 @@ pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py pytensor/tensor/rewriting/basic.py
pytensor/tensor/slinalg.py
pytensor/tensor/type.py pytensor/tensor/type.py
pytensor/tensor/type_other.py pytensor/tensor/type_other.py
pytensor/tensor/variable.py pytensor/tensor/variable.py
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论