提交 02910f82 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add some missing type annotations to LocalOptimizer and FromFunctionLocalOptimizer

上级 1ddf666e
...@@ -16,6 +16,7 @@ import warnings ...@@ -16,6 +16,7 @@ import warnings
from collections import OrderedDict, UserList, defaultdict, deque from collections import OrderedDict, UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import reduce from functools import reduce
from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -32,7 +33,7 @@ from aesara.graph.basic import ( ...@@ -32,7 +33,7 @@ from aesara.graph.basic import (
nodes_constructed, nodes_constructed,
) )
from aesara.graph.features import Feature, NodeFinder from aesara.graph.features import Feature, NodeFinder
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.utils import AssocList from aesara.graph.utils import AssocList
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
...@@ -1025,7 +1026,9 @@ class LocalOptimizer(abc.ABC): ...@@ -1025,7 +1026,9 @@ class LocalOptimizer(abc.ABC):
return None return None
@abc.abstractmethod @abc.abstractmethod
def transform(self, fgraph, node, *args, **kwargs): def transform(
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
) -> Union[bool, List[Variable], Dict[Variable, Variable]]:
r"""Transform a subgraph whose output is `node`. r"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of the Subclasses should implement this function so that it returns one of the
...@@ -1198,7 +1201,11 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -1198,7 +1201,11 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
print(f"{' ' * level}{self.transform} id={id(self)}", file=stream) print(f"{' ' * level}{self.transform} id={id(self)}", file=stream)
def local_optimizer(tracks, inplace=False, requirements=()): def local_optimizer(
tracks: Optional[List[Union[Op, type]]],
inplace: bool = False,
requirements: Optional[Tuple[type, ...]] = (),
):
def decorator(f): def decorator(f):
if tracks is not None: if tracks is not None:
if len(tracks) == 0: if len(tracks) == 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论