提交 ff802130 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.opt

上级 8b86e270
...@@ -15,10 +15,12 @@ import traceback ...@@ -15,10 +15,12 @@ import traceback
import warnings import warnings
from collections import UserList, defaultdict, deque from collections import UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial, reduce from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing_extensions import TypeAlias
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh from aesara.graph import destroyhandler as dh
...@@ -1122,9 +1124,9 @@ class LocalOptTracker: ...@@ -1122,9 +1124,9 @@ class LocalOptTracker:
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance.""" r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance."""
def __init__(self): def __init__(self):
self.tracked_instances = {} self.tracked_instances: Dict[Op, List[LocalOptimizer]] = {}
self.tracked_types = {} self.tracked_types: Dict[TypeAlias, List[LocalOptimizer]] = {}
self.untracked_opts = [] self.untracked_opts: List[LocalOptimizer] = []
def add_tracker(self, rw: LocalOptimizer): def add_tracker(self, rw: LocalOptimizer):
"""Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally.""" """Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally."""
...@@ -1139,12 +1141,12 @@ class LocalOptTracker: ...@@ -1139,12 +1141,12 @@ class LocalOptTracker:
else: else:
self.tracked_instances.setdefault(c, []).append(rw) self.tracked_instances.setdefault(c, []).append(rw)
def _find_impl(self, cls): def _find_impl(self, cls) -> List[LocalOptimizer]:
r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance. r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance.
This based on `functools._find_impl`. This based on `functools._find_impl`.
""" """
mro = functools._compose_mro(cls, self.tracked_types.keys()) mro = _compose_mro(cls, self.tracked_types.keys())
matches = [] matches = []
for t in mro: for t in mro:
match = self.tracked_types.get(t, None) match = self.tracked_types.get(t, None)
...@@ -1185,7 +1187,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1185,7 +1187,7 @@ class LocalOptGroup(LocalOptimizer):
def __init__( def __init__(
self, self,
*optimizers: Sequence[Rewriter], *optimizers: Rewriter,
apply_all_opts: bool = False, apply_all_opts: bool = False,
profile: bool = False, profile: bool = False,
): ):
...@@ -1205,9 +1207,6 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1205,9 +1207,6 @@ class LocalOptGroup(LocalOptimizer):
""" """
super().__init__() super().__init__()
if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])
self.opts: Sequence[Rewriter] = optimizers self.opts: Sequence[Rewriter] = optimizers
assert isinstance(self.opts, tuple) assert isinstance(self.opts, tuple)
......
...@@ -115,10 +115,6 @@ check_untyped_defs = False ...@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.graph.opt]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.optdb] [mypy-aesara.graph.optdb]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论