Unverified 提交 3e58e7ce authored 作者: Nicolas Legrand's avatar Nicolas Legrand 提交者: GitHub

Add type hints + docstrings details to theano.link.basics.py (#272)

* Add type hints and docstrings for theano.link.basics.py * Rename theano to aesara in docstrings * Add `StorageMapType` and `OutputStorageType` * Fix `make_thunk` * Remove `from __future__ import annotations`
上级 37444647
import typing
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NoReturn,
Optional,
Set,
Tuple,
Union,
)
from numpy import ndarray
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.graph.utils import MetaObject
from aesara.link.utils import gc_helper, map_storage, raise_with_op, streamline from aesara.link.utils import gc_helper, map_storage, raise_with_op, streamline
from aesara.utils import deprecated, difference, to_return_values from aesara.utils import deprecated, difference, to_return_values
if TYPE_CHECKING:
from aesara.link.c.basic import OpWiseCLinker
from aesara.link.vm import VMLinker
from aesara.tensor.var import TensorVariable
StorageMapType = Dict[Variable, List[Optional[Union[ndarray, slice]]]]
OutputStorageType = List[Optional[List[Any]]]
ThunkType = Tuple[Callable[[], NoReturn], List["Container"], List["Container"]]
class Container: class Container:
""" """
This class joins a variable with its computed value. This class joins a variable with its computed value.
...@@ -17,14 +42,14 @@ class Container: ...@@ -17,14 +42,14 @@ class Container:
Parameters Parameters
---------- ----------
r : a Variable or a Type r : :py:class:`aerasa.graph.utils.MetaObject`
storage storage
A list of length 1, whose element is the value for `r`. A list of length 1, whose element is the value for `r`.
readonly : bool readonly : bool
True indicates that this should not be setable by Function[r] = val. True indicates that this should not be setable by Function[r] = val.
strict : bool strict : bool
If True, we don't allow type casting. If True, we don't allow type casting.
allow_downcast allow_downcast : bool
If True (and `strict` is False), allow upcasting of type, but not If True (and `strict` is False), allow upcasting of type, but not
downcasting. If False, prevent it. If None (default), allows only downcasting. If False, prevent it. If None (default), allows only
downcasting of float to floatX scalar. downcasting of float to floatX scalar.
...@@ -35,14 +60,14 @@ class Container: ...@@ -35,14 +60,14 @@ class Container:
def __init__( def __init__(
self, self,
r, r: MetaObject,
storage, storage: Any,
*, *,
readonly=False, readonly: bool = False,
strict=False, strict: bool = False,
allow_downcast=None, allow_downcast: Optional[bool] = None,
name=None, name: Optional[str] = None,
): ) -> None:
if not isinstance(storage, list) or not len(storage) >= 1: if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one") raise TypeError("storage must be a list of length at least one")
if isinstance(r, CType): if isinstance(r, CType):
...@@ -60,10 +85,10 @@ class Container: ...@@ -60,10 +85,10 @@ class Container:
self.strict = strict self.strict = strict
self.allow_downcast = allow_downcast self.allow_downcast = allow_downcast
def __get__(self): def __get__(self) -> Any:
return self.storage[0] return self.storage[0]
def __set__(self, value): def __set__(self, value: Any) -> None:
if self.readonly: if self.readonly:
raise Exception(f"Cannot set readonly storage: {self.name}") raise Exception(f"Cannot set readonly storage: {self.name}")
try: try:
...@@ -98,7 +123,7 @@ class Container: ...@@ -98,7 +123,7 @@ class Container:
def __repr__(self): def __repr__(self):
return "<" + repr(self.storage[0]) + ">" return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo): def __deepcopy__(self, memo: Dict[int, Any]) -> "Container":
data_was_in_memo = id(self.storage[0]) in memo data_was_in_memo = id(self.storage[0]) in memo
r = type(self)( r = type(self)(
deepcopy(self.type, memo=memo), deepcopy(self.type, memo=memo),
...@@ -132,41 +157,39 @@ class Linker: ...@@ -132,41 +157,39 @@ class Linker:
allow_gc : optional, bool allow_gc : optional, bool
Configures if garbage collection is enabled. Configures if garbage collection is enabled.
scheduler : callable scheduler : callable
A scheduling function that takes a FunctionGraph and returns a list of Apply nodes. A scheduling function that takes a FunctionGraph and returns
Defaults to the .toposort() method of the FunctionGraph. a list of Apply nodes. Defaults to the .toposort() method of
the FunctionGraph.
""" """
def __init__( def __init__(
self, self,
*, *,
allow_gc: typing.Optional[bool] = None, allow_gc: Optional[bool] = None,
scheduler: typing.Callable[[FunctionGraph], typing.List[Apply]] = None, scheduler: Callable[[FunctionGraph], List[Apply]] = None,
): ) -> None:
self._allow_gc = allow_gc self._allow_gc = allow_gc
self._scheduler = scheduler self._scheduler = scheduler
super().__init__() super().__init__()
@property @property
def allow_gc(self) -> typing.Optional[bool]: def allow_gc(self) -> Optional[bool]:
"""Determines if the linker may allow garbage collection. """Determines if the linker may allow garbage collection.
None means undefined. Returns
-------
_allow_gc : optional, bool
None means undefined.
""" """
return self._allow_gc return self._allow_gc
def clone(self, allow_gc: typing.Optional[bool] = None): def clone(self, allow_gc: Optional[bool] = None) -> "Linker":
new = copy(self) new = copy(self)
if allow_gc is not None: if allow_gc is not None:
new._allow_gc = allow_gc new._allow_gc = allow_gc
return new return new
def make_thunk( def make_thunk(self, **kwargs) -> ThunkType:
self,
) -> typing.Tuple[
typing.Callable[[], typing.NoReturn],
typing.List[Container],
typing.List[Container],
]:
""" """
This function must return a triplet (function, input_variables, This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the output_variables) where function is a thunk that operates on the
...@@ -193,7 +216,7 @@ class Linker: ...@@ -193,7 +216,7 @@ class Linker:
) )
@deprecated("Marked for deletion. Only tests use it.") @deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single=True, **kwargs): def make_function(self, unpack_single: bool = True, **kwargs) -> Callable:
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the fgraph used by this L{Linker} and returns values corresponding the the
...@@ -201,6 +224,13 @@ class Linker: ...@@ -201,6 +224,13 @@ class Linker:
operate in the same storage the fgraph uses, else independent storage operate in the same storage the fgraph uses, else independent storage
will be allocated for the function. will be allocated for the function.
Parameters
----------
unpack_single : bool
If `unpack_single` is True (default) and that the function has only one
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
Examples Examples
-------- --------
e = x + y e = x + y
...@@ -209,10 +239,6 @@ class Linker: ...@@ -209,10 +239,6 @@ class Linker:
print fn(1.0, 2.0) # 3.0 print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
If unpack_single is True (default) and that the function has only one
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
""" """
thunk, inputs, outputs = self.make_thunk(**kwargs) thunk, inputs, outputs = self.make_thunk(**kwargs)
...@@ -231,23 +257,19 @@ class Linker: ...@@ -231,23 +257,19 @@ class Linker:
else: else:
return [variable.data for variable in outputs] return [variable.data for variable in outputs]
execute.thunk = thunk
execute.inputs = inputs
execute.outputs = outputs
return execute return execute
def schedule(self, fgraph: FunctionGraph) -> typing.List[Apply]: def schedule(self, fgraph: FunctionGraph) -> List[Apply]:
"""Runs the scheduler (if set) or the toposort on the FunctionGraph. """Runs the scheduler (if set) or the toposort on the FunctionGraph.
Parameters Parameters
---------- ----------
fgraph : FunctionGraph fgraph : :py:class:`aerasa.graph.fg.FunctionGraph`
A graph to compute the schedule for. A graph to compute the schedule for.
Returns Returns
------- -------
nodes : list of Apply nodes nodes : list of :py:class:`aesara.graph.basic.Apply` nodes
The result of the scheduling or toposort operation. The result of the scheduling or toposort operation.
""" """
if callable(self._scheduler): if callable(self._scheduler):
...@@ -262,14 +284,20 @@ class LocalLinker(Linker): ...@@ -262,14 +284,20 @@ class LocalLinker(Linker):
""" """
def make_thunk(self, input_storage=None, output_storage=None, storage_map=None): def make_thunk(
self,
input_storage: Optional[Any] = None,
output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None,
**kwargs,
) -> ThunkType:
return self.make_all( return self.make_all(
input_storage=input_storage, input_storage=input_storage,
output_storage=output_storage, output_storage=output_storage,
storage_map=storage_map, storage_map=storage_map,
)[:3] )[:3]
def make_all(self, input_storage, output_storage): def make_all(self, input_storage, output_storage, storage_map):
# By convention, subclasses of LocalLinker should implement this function! # By convention, subclasses of LocalLinker should implement this function!
# #
# This function should return a tuple of 5 things # This function should return a tuple of 5 things
...@@ -290,19 +318,27 @@ class PerformLinker(LocalLinker): ...@@ -290,19 +318,27 @@ class PerformLinker(LocalLinker):
""" """
def __init__(self, allow_gc=None, schedule=None): def __init__(self, allow_gc: Optional[bool] = None, schedule: None = None) -> None:
if allow_gc is None: if allow_gc is None:
allow_gc = config.allow_gc allow_gc = config.allow_gc
self.fgraph = None self.fgraph: Optional[FunctionGraph] = None
super().__init__(allow_gc=allow_gc, scheduler=schedule) super().__init__(allow_gc=allow_gc, scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(
self,
fgraph: FunctionGraph,
no_recycling: Optional[
Union[List, Set["TensorVariable"], Set[Union[Variable, "TensorVariable"]]]
] = None,
profile: Optional[bool] = None,
) -> "PerformLinker":
""" """
Parameters Parameters
---------- ----------
fgraph fgraph : :py:class:`aesara.graph.fg.FunctionGraph` instance
A PerformLinker can have accepted one FunctionGraph instance at a time. A :py:class:`aesara.link.basic.PerformLinker` instance can have accepted
one :py:class:`aesara.graph.fg.FunctionGraph` instance at a time.
no_recycling no_recycling
WRITEME WRITEME
...@@ -323,16 +359,25 @@ class PerformLinker(LocalLinker): ...@@ -323,16 +359,25 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, input_storage=None, output_storage=None, storage_map=None): def make_all(
self,
input_storage: Optional[Any] = None,
output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None,
) -> Union[
Tuple[Callable, List[Container], List[Container], List[Callable], List[Apply]],
Tuple[Callable, List[Container], List[Container], List[Any], List[Any]],
]:
""" """
Returns Function to run all nodes, list of input containers, list of outputs Returns Function to run all nodes, list of input containers, list of outputs
Parameters Parameters
---------- ----------
input_storage input_storage : optional, list
list of storages corresponding to fgraph.inputs list of storages corresponding to fgraph.inputs
output_storage output_storage
list of storages corresponding to fgraph.outputs list of storages corresponding to fgraph.outputs
storage_map : iterable
Returns Returns
------- -------
...@@ -342,7 +387,7 @@ class PerformLinker(LocalLinker): ...@@ -342,7 +387,7 @@ class PerformLinker(LocalLinker):
(for all programs). (for all programs).
""" """
fgraph = self.fgraph fgraph: Any = self.fgraph
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
...@@ -367,10 +412,7 @@ class PerformLinker(LocalLinker): ...@@ -367,10 +412,7 @@ class PerformLinker(LocalLinker):
thunks[-1].outputs = [storage_map[v] for v in node.outputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs]
computed, last_user = gc_helper(order) computed, last_user = gc_helper(order)
if self.allow_gc: post_thunk_old_storage: Any = [] if self.allow_gc else None
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
for node in order: for node in order:
if self.allow_gc: if self.allow_gc:
...@@ -452,12 +494,16 @@ class WrapLinker(Linker): ...@@ -452,12 +494,16 @@ class WrapLinker(Linker):
""" """
def __init__(self, linkers, wrapper): def __init__(
self,
linkers: Union[List["OpWiseCLinker"], List[PerformLinker], List["VMLinker"]],
wrapper: Callable,
) -> None:
self.fgraph = None self.fgraph = None
self.linkers = linkers self.linkers = linkers
self.wrapper = wrapper self.wrapper = wrapper
def __copy__(self): def __copy__(self) -> "WrapLinker":
""" """
Shallow copy of a WrapLinker. Shallow copy of a WrapLinker.
...@@ -483,12 +529,17 @@ class WrapLinker(Linker): ...@@ -483,12 +529,17 @@ class WrapLinker(Linker):
wrapper=self.wrapper, wrapper=self.wrapper,
) )
def accept(self, fgraph, no_recycling=None, profile=None): def accept(
self,
fgraph: FunctionGraph,
no_recycling: Optional[Union[Set["TensorVariable"], List]] = None,
profile: None = None,
) -> "WrapLinker":
""" """
Parameters Parameters
---------- ----------
fgraph : FunctionGraph fgraph : :py:class:`aesara.graph.fg.FunctionGraph`
The fgraph which we will link. The fgraph which we will link.
no_recycling : a list of Variables that belong to fgraph. no_recycling : a list of Variables that belong to fgraph.
If a Variable is in no_recycling, L{WrapLinker} will clear If a Variable is in no_recycling, L{WrapLinker} will clear
...@@ -506,10 +557,16 @@ class WrapLinker(Linker): ...@@ -506,10 +557,16 @@ class WrapLinker(Linker):
self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers] self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers]
return self return self
def pre(self, f, inputs, order, thunk_groups): def pre(
self,
f: "WrapLinker",
inputs: Union[List[ndarray], List[Optional[float]]],
order: List[Apply],
thunk_groups: List[Tuple[Callable]],
) -> None:
pass pass
def make_thunk(self, **kwargs): def make_thunk(self, **kwargs) -> ThunkType:
no_recycling = self.no_recycling no_recycling = self.no_recycling
make_all = [self.linkers[0].make_all(**kwargs)] make_all = [self.linkers[0].make_all(**kwargs)]
...@@ -559,7 +616,9 @@ class WrapLinker(Linker): ...@@ -559,7 +616,9 @@ class WrapLinker(Linker):
return f, inputs0, outputs0 return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers): def WrapLinkerMany(
linkers: Union[List["OpWiseCLinker"], List["VMLinker"]], wrappers: List[Callable]
) -> WrapLinker:
""" """
Variant on WrapLinker that runs a series of wrapper functions instead of Variant on WrapLinker that runs a series of wrapper functions instead of
just one. just one.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论