提交 02910f82 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add some missing type annotations to LocalOptimizer and FromFunctionLocalOptimizer

上级 1ddf666e
......@@ -16,6 +16,7 @@ import warnings
from collections import OrderedDict, UserList, defaultdict, deque
from collections.abc import Iterable
from functools import reduce
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -32,7 +33,7 @@ from aesara.graph.basic import (
nodes_constructed,
)
from aesara.graph.features import Feature, NodeFinder
from aesara.graph.fg import InconsistencyError
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import Op
from aesara.graph.utils import AssocList
from aesara.misc.ordered_set import OrderedSet
......@@ -1025,7 +1026,9 @@ class LocalOptimizer(abc.ABC):
return None
@abc.abstractmethod
def transform(self, fgraph, node, *args, **kwargs):
def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[bool, List[Variable], Dict[Variable, Variable]]:
r"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of the
......@@ -1198,7 +1201,11 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
print(f"{' ' * level}{self.transform} id={id(self)}", file=stream)
def local_optimizer(tracks, inplace=False, requirements=()):
def local_optimizer(
tracks: Optional[List[Union[Op, type]]],
inplace: bool = False,
requirements: Optional[Tuple[type, ...]] = (),
):
def decorator(f):
if tracks is not None:
if len(tracks) == 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论