提交 51464cd1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.optdb

上级 ff802130
......@@ -54,6 +54,8 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
class Rewriter(abc.ABC):
"""Abstract base class for graph/term rewriters."""
name: Optional[str] = None
@abc.abstractmethod
def add_requirements(self, fgraph: FunctionGraph):
r"""Add `Feature`\s and other requirements to a `FunctionGraph`."""
......
......@@ -3,7 +3,7 @@ import math
import sys
from functools import cmp_to_key
from io import StringIO
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
from aesara.configdefaults import config
from aesara.graph import opt as aesara_opt
......@@ -186,12 +186,16 @@ class OptimizationQuery:
def __init__(
self,
include: Sequence[str],
require: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
include: Iterable[str],
require: Optional[Union[OrderedSet, Sequence[str]]] = None,
exclude: Optional[Union[OrderedSet, Sequence[str]]] = None,
subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
position_cutoff: float = math.inf,
extra_optimizations: Optional[Sequence[OptimizersType]] = None,
extra_optimizations: Optional[
Sequence[
Tuple[Union["OptimizationQuery", OptimizersType], Union[int, float]]
]
] = None,
):
"""
......@@ -219,18 +223,14 @@ class OptimizationQuery:
"""
self.include = OrderedSet(include)
self.require = require or OrderedSet()
self.exclude = exclude or OrderedSet()
self.require = OrderedSet(require) if require else OrderedSet()
self.exclude = OrderedSet(exclude) if exclude else OrderedSet()
self.subquery = subquery or {}
self.position_cutoff = position_cutoff
self.name: Optional[str] = None
if extra_optimizations is None:
extra_optimizations = []
self.extra_optimizations = extra_optimizations
if isinstance(self.require, (list, tuple)):
self.require = OrderedSet(self.require)
if isinstance(self.exclude, (list, tuple)):
self.exclude = OrderedSet(self.exclude)
self.extra_optimizations = list(extra_optimizations)
def __str__(self):
return (
......@@ -279,7 +279,9 @@ class OptimizationQuery:
self.extra_optimizations,
)
def register(self, *optimizations: Sequence[OptimizersType]) -> "OptimizationQuery":
def register(
self, *optimizations: Tuple["OptimizationQuery", Union[int, float]]
) -> "OptimizationQuery":
"""Include the given optimizations."""
return OptimizationQuery(
self.include,
......@@ -394,10 +396,9 @@ class SequenceDB(OptimizationDatabase):
self.__position__ = {}
self.failure_callback = failure_callback
def register(
self, name, obj, *tags, position: Union[str, int, float] = "last", **kwargs
):
def register(self, name, obj, *tags, **kwargs):
super().register(name, obj, *tags, **kwargs)
position = kwargs.pop("position", "last")
if position == "last":
if len(self.__position__) == 0:
self.__position__[name] = 0
......
......@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.optdb]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.opt_utils]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论