提交 2fbc3bb6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.link.basic

上级 15c55a3d
...@@ -8,7 +8,6 @@ from typing import ( ...@@ -8,7 +8,6 @@ from typing import (
List, List,
Optional, Optional,
Sequence, Sequence,
Set,
Tuple, Tuple,
Union, Union,
) )
...@@ -24,10 +23,12 @@ from aesara.utils import difference ...@@ -24,10 +23,12 @@ from aesara.utils import difference
if TYPE_CHECKING: if TYPE_CHECKING:
from aesara.compile.profiling import ProfileStats
from aesara.tensor.var import TensorVariable from aesara.tensor.var import TensorVariable
StorageMapType = Dict[Variable, List[Optional[Union[ndarray, slice]]]] StorageMapType = Dict[Variable, List[Optional[Union[ndarray, slice]]]]
OutputStorageType = List[Optional[List[Any]]] OutputStorageType = List[Optional[List[Any]]]
InputStorageType = OutputStorageType
ThunkType = Tuple[Callable[[], None], List["Container"], List["Container"]] ThunkType = Tuple[Callable[[], None], List["Container"], List["Container"]]
...@@ -165,7 +166,7 @@ class Linker(ABC): ...@@ -165,7 +166,7 @@ class Linker(ABC):
self, self,
*, *,
allow_gc: Optional[bool] = None, allow_gc: Optional[bool] = None,
scheduler: Callable[[FunctionGraph], List[Apply]] = None, scheduler: Optional[Callable[[FunctionGraph], List[Apply]]] = None,
) -> None: ) -> None:
self._allow_gc = allow_gc self._allow_gc = allow_gc
self._scheduler = scheduler self._scheduler = scheduler
...@@ -189,7 +190,9 @@ class Linker(ABC): ...@@ -189,7 +190,9 @@ class Linker(ABC):
return new return new
@abstractmethod @abstractmethod
def make_thunk(self, **kwargs) -> ThunkType: def make_thunk(
self, **kwargs
) -> Tuple[Callable, InputStorageType, OutputStorageType]:
""" """
This function must return a triplet (function, input_variables, This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the output_variables) where function is a thunk that operates on the
...@@ -239,26 +242,37 @@ class LocalLinker(Linker): ...@@ -239,26 +242,37 @@ class LocalLinker(Linker):
def make_thunk( def make_thunk(
self, self,
input_storage: Optional[Any] = None, input_storage: Optional[InputStorageType] = None,
output_storage: Optional[OutputStorageType] = None, output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None, storage_map: Optional[StorageMapType] = None,
**kwargs, **kwargs,
) -> ThunkType: ) -> Tuple[Callable[[], None], InputStorageType, OutputStorageType]:
return self.make_all( return self.make_all(
input_storage=input_storage, input_storage=input_storage,
output_storage=output_storage, output_storage=output_storage,
storage_map=storage_map, storage_map=storage_map,
)[:3] )[:3]
def make_all(self, input_storage, output_storage, storage_map): def make_all(
# By convention, subclasses of LocalLinker should implement this function! self,
# input_storage: Optional[InputStorageType],
# This function should return a tuple of 5 things output_storage: Optional[OutputStorageType],
# 1. function to run the program storage_map: Optional[StorageMapType],
# 2. input storage ) -> Tuple[
# 3. output storage Callable[[], None],
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1) InputStorageType,
# 5. order: list of nodes, in the order they will be run by the function in (1) OutputStorageType,
List[ThunkType],
List[Apply],
]:
"""
This function should return a tuple of 5 things
1. function to run the program
2. input storage
3. output storage
4. thunks: list of nodes' functions in the order they will be run by the function in (1)
5. order: list of nodes, in the order they will be run by the function in (1)
"""
raise NotImplementedError( raise NotImplementedError(
f"make_all method of {type(self)} is not implemented." f"make_all method of {type(self)} is not implemented."
) )
...@@ -272,7 +286,7 @@ class PerformLinker(LocalLinker): ...@@ -272,7 +286,7 @@ class PerformLinker(LocalLinker):
""" """
def __init__( def __init__(
self, allow_gc: Optional[bool] = None, schedule: Callable = None self, allow_gc: Optional[bool] = None, schedule: Optional[Callable] = None
) -> None: ) -> None:
if allow_gc is None: if allow_gc is None:
allow_gc = config.allow_gc allow_gc = config.allow_gc
...@@ -283,15 +297,15 @@ class PerformLinker(LocalLinker): ...@@ -283,15 +297,15 @@ class PerformLinker(LocalLinker):
self, self,
fgraph: FunctionGraph, fgraph: FunctionGraph,
no_recycling: Optional[Sequence[Variable]] = None, no_recycling: Optional[Sequence[Variable]] = None,
profile: Optional[bool] = None, profile: Optional[Union[bool, "ProfileStats"]] = None,
) -> "PerformLinker": ) -> "PerformLinker":
""" """Associate a `FunctionGraph` with this `Linker`.
Parameters Parameters
---------- ----------
fgraph fgraph
A :py:class:`aesara.link.basic.PerformLinker` instance can have accepted A `PerformLinker` instance can have accepted one `FunctionGraph`
one :py:class:`aesara.graph.fg.FunctionGraph` instance at a time. instance at a time.
no_recycling no_recycling
WRITEME WRITEME
...@@ -309,32 +323,10 @@ class PerformLinker(LocalLinker): ...@@ -309,32 +323,10 @@ class PerformLinker(LocalLinker):
def make_all( def make_all(
self, self,
input_storage: Optional[Any] = None, input_storage=None,
output_storage: Optional[OutputStorageType] = None, output_storage=None,
storage_map: Optional[StorageMapType] = None, storage_map=None,
) -> Union[ ):
Tuple[Callable, List[Container], List[Container], List[Callable], List[Apply]],
Tuple[Callable, List[Container], List[Container], List[Any], List[Any]],
]:
"""
Returns Function to run all nodes, list of input containers, list of outputs
Parameters
----------
input_storage : optional, list
list of storages corresponding to fgraph.inputs
output_storage
list of storages corresponding to fgraph.outputs
storage_map : iterable
Returns
-------
object
Function to run all nodes, list of input containers, list of output
containers, list of thunks (for all programs), list of nodes
(for all programs).
"""
fgraph: Any = self.fgraph fgraph: Any = self.fgraph
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
...@@ -421,7 +413,8 @@ class WrapLinker(Linker): ...@@ -421,7 +413,8 @@ class WrapLinker(Linker):
Parameters Parameters
---------- ----------
linkers : list of L{LocalLinker} subclasses, whose make_all() method returns linkers
List of L{LocalLinker} subclasses, whose make_all() method returns
thunks in the same order. thunks in the same order.
For each node in the graph, each linker will provide a For each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's thunk. This class makes it possible to iterate over each linker's
...@@ -447,7 +440,7 @@ class WrapLinker(Linker): ...@@ -447,7 +440,7 @@ class WrapLinker(Linker):
linkers: Sequence[PerformLinker], linkers: Sequence[PerformLinker],
wrapper: Callable, wrapper: Callable,
) -> None: ) -> None:
self.fgraph = None self.fgraph: Optional[FunctionGraph] = None
self.linkers = linkers self.linkers = linkers
self.wrapper = wrapper self.wrapper = wrapper
...@@ -480,8 +473,8 @@ class WrapLinker(Linker): ...@@ -480,8 +473,8 @@ class WrapLinker(Linker):
def accept( def accept(
self, self,
fgraph: FunctionGraph, fgraph: FunctionGraph,
no_recycling: Optional[Union[Set["TensorVariable"], List]] = None, no_recycling: Optional[Sequence["TensorVariable"]] = None,
profile: None = None, profile: Optional[Union[bool, "ProfileStats"]] = None,
) -> "WrapLinker": ) -> "WrapLinker":
""" """
...@@ -514,7 +507,7 @@ class WrapLinker(Linker): ...@@ -514,7 +507,7 @@ class WrapLinker(Linker):
) -> None: ) -> None:
pass pass
def make_thunk(self, **kwargs) -> ThunkType: def make_thunk(self, **kwargs):
no_recycling = self.no_recycling no_recycling = self.no_recycling
make_all = [self.linkers[0].make_all(**kwargs)] make_all = [self.linkers[0].make_all(**kwargs)]
......
...@@ -10,7 +10,7 @@ from operator import itemgetter ...@@ -10,7 +10,7 @@ from operator import itemgetter
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from textwrap import indent from textwrap import indent
from types import FunctionType from types import FunctionType
from typing import Any, Callable, Dict, Iterable, List, NoReturn, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -130,7 +130,7 @@ def streamline( ...@@ -130,7 +130,7 @@ def streamline(
post_thunk_old_storage=None, post_thunk_old_storage=None,
no_recycling=None, no_recycling=None,
nice_errors=True, nice_errors=True,
) -> Callable[[], NoReturn]: ) -> Callable[[], None]:
""" """
WRITEME WRITEME
......
...@@ -115,10 +115,6 @@ check_untyped_defs = False ...@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.link.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.utils] [mypy-aesara.link.utils]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论