提交 2fbc3bb6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.link.basic

上级 15c55a3d
......@@ -8,7 +8,6 @@ from typing import (
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
......@@ -24,10 +23,12 @@ from aesara.utils import difference
if TYPE_CHECKING:
from aesara.compile.profiling import ProfileStats
from aesara.tensor.var import TensorVariable
StorageMapType = Dict[Variable, List[Optional[Union[ndarray, slice]]]]
OutputStorageType = List[Optional[List[Any]]]
InputStorageType = OutputStorageType
ThunkType = Tuple[Callable[[], None], List["Container"], List["Container"]]
......@@ -165,7 +166,7 @@ class Linker(ABC):
self,
*,
allow_gc: Optional[bool] = None,
scheduler: Callable[[FunctionGraph], List[Apply]] = None,
scheduler: Optional[Callable[[FunctionGraph], List[Apply]]] = None,
) -> None:
self._allow_gc = allow_gc
self._scheduler = scheduler
......@@ -189,7 +190,9 @@ class Linker(ABC):
return new
@abstractmethod
def make_thunk(self, **kwargs) -> ThunkType:
def make_thunk(
self, **kwargs
) -> Tuple[Callable, InputStorageType, OutputStorageType]:
"""
This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the
......@@ -239,26 +242,37 @@ class LocalLinker(Linker):
def make_thunk(
self,
input_storage: Optional[Any] = None,
input_storage: Optional[InputStorageType] = None,
output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None,
**kwargs,
) -> ThunkType:
) -> Tuple[Callable[[], None], InputStorageType, OutputStorageType]:
return self.make_all(
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)[:3]
def make_all(self, input_storage, output_storage, storage_map):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
def make_all(
self,
input_storage: Optional[InputStorageType],
output_storage: Optional[OutputStorageType],
storage_map: Optional[StorageMapType],
) -> Tuple[
Callable[[], None],
InputStorageType,
OutputStorageType,
List[ThunkType],
List[Apply],
]:
"""
This function should return a tuple of 5 things
1. function to run the program
2. input storage
3. output storage
4. thunks: list of nodes' functions in the order they will be run by the function in (1)
5. order: list of nodes, in the order they will be run by the function in (1)
"""
raise NotImplementedError(
f"make_all method of {type(self)} is not implemented."
)
......@@ -272,7 +286,7 @@ class PerformLinker(LocalLinker):
"""
def __init__(
self, allow_gc: Optional[bool] = None, schedule: Callable = None
self, allow_gc: Optional[bool] = None, schedule: Optional[Callable] = None
) -> None:
if allow_gc is None:
allow_gc = config.allow_gc
......@@ -283,15 +297,15 @@ class PerformLinker(LocalLinker):
self,
fgraph: FunctionGraph,
no_recycling: Optional[Sequence[Variable]] = None,
profile: Optional[bool] = None,
profile: Optional[Union[bool, "ProfileStats"]] = None,
) -> "PerformLinker":
"""
"""Associate a `FunctionGraph` with this `Linker`.
Parameters
----------
fgraph
A :py:class:`aesara.link.basic.PerformLinker` instance can have accepted
one :py:class:`aesara.graph.fg.FunctionGraph` instance at a time.
A `PerformLinker` instance can have accepted one `FunctionGraph`
instance at a time.
no_recycling
WRITEME
......@@ -309,32 +323,10 @@ class PerformLinker(LocalLinker):
def make_all(
self,
input_storage: Optional[Any] = None,
output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None,
) -> Union[
Tuple[Callable, List[Container], List[Container], List[Callable], List[Apply]],
Tuple[Callable, List[Container], List[Container], List[Any], List[Any]],
]:
"""
Returns Function to run all nodes, list of input containers, list of outputs
Parameters
----------
input_storage : optional, list
list of storages corresponding to fgraph.inputs
output_storage
list of storages corresponding to fgraph.outputs
storage_map : iterable
Returns
-------
object
Function to run all nodes, list of input containers, list of output
containers, list of thunks (for all programs), list of nodes
(for all programs).
"""
input_storage=None,
output_storage=None,
storage_map=None,
):
fgraph: Any = self.fgraph
order = self.schedule(fgraph)
no_recycling = self.no_recycling
......@@ -421,7 +413,8 @@ class WrapLinker(Linker):
Parameters
----------
linkers : list of L{LocalLinker} subclasses, whose make_all() method returns
linkers
List of L{LocalLinker} subclasses, whose make_all() method returns
thunks in the same order.
For each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's
......@@ -447,7 +440,7 @@ class WrapLinker(Linker):
linkers: Sequence[PerformLinker],
wrapper: Callable,
) -> None:
self.fgraph = None
self.fgraph: Optional[FunctionGraph] = None
self.linkers = linkers
self.wrapper = wrapper
......@@ -480,8 +473,8 @@ class WrapLinker(Linker):
def accept(
self,
fgraph: FunctionGraph,
no_recycling: Optional[Union[Set["TensorVariable"], List]] = None,
profile: None = None,
no_recycling: Optional[Sequence["TensorVariable"]] = None,
profile: Optional[Union[bool, "ProfileStats"]] = None,
) -> "WrapLinker":
"""
......@@ -514,7 +507,7 @@ class WrapLinker(Linker):
) -> None:
pass
def make_thunk(self, **kwargs) -> ThunkType:
def make_thunk(self, **kwargs):
no_recycling = self.no_recycling
make_all = [self.linkers[0].make_all(**kwargs)]
......
......@@ -10,7 +10,7 @@ from operator import itemgetter
from tempfile import NamedTemporaryFile
from textwrap import indent
from types import FunctionType
from typing import Any, Callable, Dict, Iterable, List, NoReturn, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
......@@ -130,7 +130,7 @@ def streamline(
post_thunk_old_storage=None,
no_recycling=None,
nice_errors=True,
) -> Callable[[], NoReturn]:
) -> Callable[[], None]:
"""
WRITEME
......
......@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.utils]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论