提交 7d54c5e4 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Enable mypy's `warn_no_return` lint

上级 ee4d4f71
...@@ -156,17 +156,12 @@ lines-after-imports = 2 ...@@ -156,17 +156,12 @@ lines-after-imports = 2
[tool.mypy] [tool.mypy]
python_version = "3.10" python_version = "3.10"
ignore_missing_imports = true ignore_missing_imports = true
no_implicit_optional = true
check_untyped_defs = false
strict_equality = true strict_equality = true
warn_redundant_casts = true warn_redundant_casts = true
warn_unused_configs = true warn_unused_configs = true
warn_unused_ignores = true warn_unused_ignores = true
warn_return_any = true warn_return_any = true
warn_no_return = false
warn_unreachable = true warn_unreachable = true
show_error_codes = true
allow_redefinition = false
files = ["pytensor", "tests"] files = ["pytensor", "tests"]
plugins = ["numpy.typing.mypy_plugin"] plugins = ["numpy.typing.mypy_plugin"]
......
...@@ -909,6 +909,7 @@ def ancestors( ...@@ -909,6 +909,7 @@ def ancestors(
def expand(r: Variable) -> Iterator[Variable] | None:
    """Return the owner's inputs of *r* (reversed) for graph traversal.

    Returns ``None`` when *r* has no owner or is in the (closure-captured)
    ``blockers`` collection, signalling `walk` not to descend further.
    """
    if not r.owner:
        return None
    if blockers and r in blockers:
        return None
    return reversed(r.owner.inputs)
yield from cast(Generator[Variable, None, None], walk(graphs, expand, False)) yield from cast(Generator[Variable, None, None], walk(graphs, expand, False))
...@@ -1011,6 +1012,7 @@ def vars_between( ...@@ -1011,6 +1012,7 @@ def vars_between(
def expand(r: Variable) -> Iterable[Variable] | None:
    """Return *r*'s owner inputs and outputs (reversed) for traversal.

    Yields ``None`` when *r* is ownerless or already in the
    (closure-captured) ``ins`` set, so `walk` stops expanding there.
    """
    if not r.owner or r in ins:
        return None
    node = r.owner
    return reversed(node.inputs + node.outputs)
yield from cast(Generator[Variable, None, None], walk(outs, expand)) yield from cast(Generator[Variable, None, None], walk(outs, expand))
...@@ -2039,13 +2041,15 @@ def get_var_by_name( ...@@ -2039,13 +2041,15 @@ def get_var_by_name(
from pytensor.graph.op import HasInnerGraph from pytensor.graph.op import HasInnerGraph
def expand(r) -> list[Variable] | None:
    """Collect the successor variables of *r* for the name search.

    Returns the owner's inputs, extended with the inner-graph outputs when
    the owning op is a `HasInnerGraph`; ``None`` when *r* has no owner.
    """
    node = r.owner
    if not node:
        return None
    successors = list(node.inputs)
    if isinstance(node.op, HasInnerGraph):
        successors.extend(node.op.inner_outputs)
    return successors
results: tuple[Variable, ...] = () results: tuple[Variable, ...] = ()
for var in walk(graphs, expand, False): for var in walk(graphs, expand, False):
......
...@@ -355,34 +355,37 @@ def local_lift_through_linalg( ...@@ -355,34 +355,37 @@ def local_lift_through_linalg(
""" """
# TODO: Simplify this if we end up Blockwising KroneckerProduct # TODO: Simplify this if we end up Blockwising KroneckerProduct
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv): if not isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
y = node.inputs[0] return None
outer_op = node.op
if y.owner and (
isinstance(y.owner.op, Blockwise)
and isinstance(y.owner.op.core_op, BlockDiagonal)
or isinstance(y.owner.op, KroneckerProduct)
):
input_matrices = y.owner.inputs
if isinstance(outer_op.core_op, MatrixInverse):
outer_f = cast(Callable, inv)
elif isinstance(outer_op.core_op, Cholesky):
outer_f = cast(Callable, cholesky)
elif isinstance(outer_op.core_op, MatrixPinv):
outer_f = cast(Callable, pinv)
else:
raise NotImplementedError # pragma: no cover
inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices] y = node.inputs[0]
outer_op = node.op
if isinstance(y.owner.op, KroneckerProduct): if y.owner and (
return [kron(*inner_matrices)] isinstance(y.owner.op, Blockwise)
elif isinstance(y.owner.op.core_op, BlockDiagonal): and isinstance(y.owner.op.core_op, BlockDiagonal)
return [block_diag(*inner_matrices)] or isinstance(y.owner.op, KroneckerProduct)
else: ):
raise NotImplementedError # pragma: no cover input_matrices = y.owner.inputs
if isinstance(outer_op.core_op, MatrixInverse):
outer_f = cast(Callable, inv)
elif isinstance(outer_op.core_op, Cholesky):
outer_f = cast(Callable, cholesky)
elif isinstance(outer_op.core_op, MatrixPinv):
outer_f = cast(Callable, pinv)
else:
raise NotImplementedError # pragma: no cover
inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
if isinstance(y.owner.op, KroneckerProduct):
return [kron(*inner_matrices)]
elif isinstance(y.owner.op.core_op, BlockDiagonal):
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover
return None
def _find_diag_from_eye_mul(potential_mul_input): def _find_diag_from_eye_mul(potential_mul_input):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论