提交 51464cd1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.optdb

上级 ff802130
...@@ -54,6 +54,8 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError): ...@@ -54,6 +54,8 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
class Rewriter(abc.ABC): class Rewriter(abc.ABC):
"""Abstract base class for graph/term rewriters.""" """Abstract base class for graph/term rewriters."""
name: Optional[str] = None
@abc.abstractmethod @abc.abstractmethod
def add_requirements(self, fgraph: FunctionGraph): def add_requirements(self, fgraph: FunctionGraph):
r"""Add `Feature`\s and other requirements to a `FunctionGraph`.""" r"""Add `Feature`\s and other requirements to a `FunctionGraph`."""
......
...@@ -3,7 +3,7 @@ import math ...@@ -3,7 +3,7 @@ import math
import sys import sys
from functools import cmp_to_key from functools import cmp_to_key
from io import StringIO from io import StringIO
from typing import Dict, Optional, Sequence, Union from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import opt as aesara_opt from aesara.graph import opt as aesara_opt
...@@ -186,12 +186,16 @@ class OptimizationQuery: ...@@ -186,12 +186,16 @@ class OptimizationQuery:
def __init__( def __init__(
self, self,
include: Sequence[str], include: Iterable[str],
require: Optional[Sequence[str]] = None, require: Optional[Union[OrderedSet, Sequence[str]]] = None,
exclude: Optional[Sequence[str]] = None, exclude: Optional[Union[OrderedSet, Sequence[str]]] = None,
subquery: Optional[Dict[str, "OptimizationQuery"]] = None, subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
position_cutoff: float = math.inf, position_cutoff: float = math.inf,
extra_optimizations: Optional[Sequence[OptimizersType]] = None, extra_optimizations: Optional[
Sequence[
Tuple[Union["OptimizationQuery", OptimizersType], Union[int, float]]
]
] = None,
): ):
""" """
...@@ -219,18 +223,14 @@ class OptimizationQuery: ...@@ -219,18 +223,14 @@ class OptimizationQuery:
""" """
self.include = OrderedSet(include) self.include = OrderedSet(include)
self.require = require or OrderedSet() self.require = OrderedSet(require) if require else OrderedSet()
self.exclude = exclude or OrderedSet() self.exclude = OrderedSet(exclude) if exclude else OrderedSet()
self.subquery = subquery or {} self.subquery = subquery or {}
self.position_cutoff = position_cutoff self.position_cutoff = position_cutoff
self.name: Optional[str] = None self.name: Optional[str] = None
if extra_optimizations is None: if extra_optimizations is None:
extra_optimizations = [] extra_optimizations = []
self.extra_optimizations = extra_optimizations self.extra_optimizations = list(extra_optimizations)
if isinstance(self.require, (list, tuple)):
self.require = OrderedSet(self.require)
if isinstance(self.exclude, (list, tuple)):
self.exclude = OrderedSet(self.exclude)
def __str__(self): def __str__(self):
return ( return (
...@@ -279,7 +279,9 @@ class OptimizationQuery: ...@@ -279,7 +279,9 @@ class OptimizationQuery:
self.extra_optimizations, self.extra_optimizations,
) )
def register(self, *optimizations: Sequence[OptimizersType]) -> "OptimizationQuery": def register(
self, *optimizations: Tuple["OptimizationQuery", Union[int, float]]
) -> "OptimizationQuery":
"""Include the given optimizations.""" """Include the given optimizations."""
return OptimizationQuery( return OptimizationQuery(
self.include, self.include,
...@@ -394,10 +396,9 @@ class SequenceDB(OptimizationDatabase): ...@@ -394,10 +396,9 @@ class SequenceDB(OptimizationDatabase):
self.__position__ = {} self.__position__ = {}
self.failure_callback = failure_callback self.failure_callback = failure_callback
def register( def register(self, name, obj, *tags, **kwargs):
self, name, obj, *tags, position: Union[str, int, float] = "last", **kwargs
):
super().register(name, obj, *tags, **kwargs) super().register(name, obj, *tags, **kwargs)
position = kwargs.pop("position", "last")
if position == "last": if position == "last":
if len(self.__position__) == 0: if len(self.__position__) == 0:
self.__position__[name] = 0 self.__position__[name] = 0
......
...@@ -115,10 +115,6 @@ check_untyped_defs = False ...@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.graph.optdb]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.opt_utils] [mypy-aesara.graph.opt_utils]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论