提交 ff802130 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.opt

上级 8b86e270
......@@ -15,10 +15,12 @@ import traceback
import warnings
from collections import UserList, defaultdict, deque
from collections.abc import Iterable
from functools import partial, reduce
from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing_extensions import TypeAlias
import aesara
from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh
......@@ -1122,9 +1124,9 @@ class LocalOptTracker:
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance."""
def __init__(self):
self.tracked_instances = {}
self.tracked_types = {}
self.untracked_opts = []
self.tracked_instances: Dict[Op, List[LocalOptimizer]] = {}
self.tracked_types: Dict[TypeAlias, List[LocalOptimizer]] = {}
self.untracked_opts: List[LocalOptimizer] = []
def add_tracker(self, rw: LocalOptimizer):
"""Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally."""
......@@ -1139,12 +1141,12 @@ class LocalOptTracker:
else:
self.tracked_instances.setdefault(c, []).append(rw)
def _find_impl(self, cls):
def _find_impl(self, cls) -> List[LocalOptimizer]:
r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance.
This based on `functools._find_impl`.
"""
mro = functools._compose_mro(cls, self.tracked_types.keys())
mro = _compose_mro(cls, self.tracked_types.keys())
matches = []
for t in mro:
match = self.tracked_types.get(t, None)
......@@ -1185,7 +1187,7 @@ class LocalOptGroup(LocalOptimizer):
def __init__(
self,
*optimizers: Sequence[Rewriter],
*optimizers: Rewriter,
apply_all_opts: bool = False,
profile: bool = False,
):
......@@ -1205,9 +1207,6 @@ class LocalOptGroup(LocalOptimizer):
"""
super().__init__()
if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])
self.opts: Sequence[Rewriter] = optimizers
assert isinstance(self.opts, tuple)
......
......@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.opt]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.optdb]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论