提交 6cf729b9 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Fix mypy in tensor/slinalg.py

上级 7d54c5e4
......@@ -2,7 +2,7 @@ import logging
import typing
import warnings
from functools import reduce
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, cast
import numpy as np
import scipy.linalg
......@@ -141,7 +141,7 @@ def cholesky(x, lower=True, on_error="raise", check_finite=False):
class SolveBase(Op):
"""Base class for `scipy.linalg` matrix equation solvers."""
__props__ = (
__props__: tuple[str, ...] = (
"lower",
"check_finite",
"b_ndim",
......@@ -352,7 +352,7 @@ def solve_triangular(
This will influence how batched dimensions are interpreted.
"""
b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
ret = Blockwise(
SolveTriangular(
lower=lower,
trans=trans,
......@@ -361,6 +361,7 @@ def solve_triangular(
b_ndim=b_ndim,
)
)(a, b)
return cast(TensorVariable, ret)
class Solve(SolveBase):
......@@ -714,9 +715,7 @@ class BilinearSolveDiscreteLyapunov(Op):
_solve_continuous_lyapunov = SolveContinuousLyapunov()
_solve_bilinear_direct_lyapunov = typing.cast(
typing.Callable, BilinearSolveDiscreteLyapunov()
)
_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
......@@ -729,7 +728,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
AA = kron(A_, A_)
X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
return typing.cast(TensorVariable, reshape(X, Q_.shape))
return cast(TensorVariable, reshape(X, Q_.shape))
def solve_discrete_lyapunov(
......@@ -765,7 +764,7 @@ def solve_discrete_lyapunov(
if method == "direct":
return _direct_solve_discrete_lyapunov(A, Q)
if method == "bilinear":
return typing.cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
......@@ -785,7 +784,7 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
"""
return typing.cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
class SolveDiscreteARE(pt.Op):
......@@ -866,9 +865,7 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
Square matrix of shape M x M, representing the solution to the DARE
"""
return typing.cast(
TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R)
)
return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
......
......@@ -24,7 +24,6 @@ pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py
pytensor/tensor/slinalg.py
pytensor/tensor/type.py
pytensor/tensor/type_other.py
pytensor/tensor/variable.py
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论