提交 0ba5c0a6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix type hints in aesara.graph.opt

上级 c736927b
...@@ -17,7 +17,9 @@ from collections import UserList, defaultdict, deque ...@@ -17,7 +17,9 @@ from collections import UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import _compose_mro, partial, reduce # type: ignore from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing_extensions import Literal
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -41,6 +43,17 @@ from aesara.utils import flatten ...@@ -41,6 +43,17 @@ from aesara.utils import flatten
_logger = logging.getLogger("aesara.graph.opt") _logger = logging.getLogger("aesara.graph.opt")
FailureCallbackType = Callable[
[
Exception,
"NavigatorOptimizer",
List[Tuple[Variable, None]],
"LocalOptimizer",
Apply,
],
None,
]
class LocalMetaOptimizerSkipAssertionError(AssertionError): class LocalMetaOptimizerSkipAssertionError(AssertionError):
"""This is an AssertionError, but instead of having the """This is an AssertionError, but instead of having the
...@@ -1770,7 +1783,12 @@ class NavigatorOptimizer(GlobalOptimizer): ...@@ -1770,7 +1783,12 @@ class NavigatorOptimizer(GlobalOptimizer):
def warn_ignore(exc, nav, repl_pairs, local_opt, node): def warn_ignore(exc, nav, repl_pairs, local_opt, node):
"""A failure callback that ignores all errors.""" """A failure callback that ignores all errors."""
def __init__(self, local_opt, ignore_newtrees="auto", failure_callback=None): def __init__(
self,
local_opt: LocalOptimizer,
ignore_newtrees: Literal[True, False, "auto"],
failure_callback: Optional[FailureCallbackType] = None,
):
self.local_opt = local_opt self.local_opt = local_opt
if ignore_newtrees == "auto": if ignore_newtrees == "auto":
self.ignore_newtrees = not getattr(local_opt, "reentrant", True) self.ignore_newtrees = not getattr(local_opt, "reentrant", True)
...@@ -1934,7 +1952,11 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1934,7 +1952,11 @@ class TopoOptimizer(NavigatorOptimizer):
"""An optimizer that applies a single `LocalOptimizer` to each node in topological order (or reverse).""" """An optimizer that applies a single `LocalOptimizer` to each node in topological order (or reverse)."""
def __init__( def __init__(
self, local_opt, order="in_to_out", ignore_newtrees=False, failure_callback=None self,
local_opt: LocalOptimizer,
order: Literal["out_to_in", "in_to_out"] = "in_to_out",
ignore_newtrees: bool = False,
failure_callback: Optional[FailureCallbackType] = None,
): ):
if order not in ("out_to_in", "in_to_out"): if order not in ("out_to_in", "in_to_out"):
raise ValueError("order must be 'out_to_in' or 'in_to_out'") raise ValueError("order must be 'out_to_in' or 'in_to_out'")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论