Unverified 提交 3e58e7ce authored 作者: Nicolas Legrand's avatar Nicolas Legrand 提交者: GitHub

Add type hints + docstrings details to theano.link.basics.py (#272)

* Add type hints and docstrings for theano.link.basics.py * Rename theano to aesara in docstrings * Add `StorageMapType` and `OutputStorageType` * Fix `make_thunk` * Remove `from __future__ import annotations`
上级 37444647
import typing
from copy import copy, deepcopy
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NoReturn,
Optional,
Set,
Tuple,
Union,
)
from numpy import ndarray
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import CType
from aesara.graph.utils import MetaObject
from aesara.link.utils import gc_helper, map_storage, raise_with_op, streamline
from aesara.utils import deprecated, difference, to_return_values
if TYPE_CHECKING:
from aesara.link.c.basic import OpWiseCLinker
from aesara.link.vm import VMLinker
from aesara.tensor.var import TensorVariable
StorageMapType = Dict[Variable, List[Optional[Union[ndarray, slice]]]]
OutputStorageType = List[Optional[List[Any]]]
ThunkType = Tuple[Callable[[], NoReturn], List["Container"], List["Container"]]
class Container:
"""
This class joins a variable with its computed value.
......@@ -17,14 +42,14 @@ class Container:
Parameters
----------
r : a Variable or a Type
r : :py:class:`aerasa.graph.utils.MetaObject`
storage
A list of length 1, whose element is the value for `r`.
readonly : bool
True indicates that this should not be setable by Function[r] = val.
strict : bool
If True, we don't allow type casting.
allow_downcast
allow_downcast : bool
If True (and `strict` is False), allow upcasting of type, but not
downcasting. If False, prevent it. If None (default), allows only
downcasting of float to floatX scalar.
......@@ -35,14 +60,14 @@ class Container:
def __init__(
self,
r,
storage,
r: MetaObject,
storage: Any,
*,
readonly=False,
strict=False,
allow_downcast=None,
name=None,
):
readonly: bool = False,
strict: bool = False,
allow_downcast: Optional[bool] = None,
name: Optional[str] = None,
) -> None:
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
if isinstance(r, CType):
......@@ -60,10 +85,10 @@ class Container:
self.strict = strict
self.allow_downcast = allow_downcast
def __get__(self):
def __get__(self) -> Any:
return self.storage[0]
def __set__(self, value):
def __set__(self, value: Any) -> None:
if self.readonly:
raise Exception(f"Cannot set readonly storage: {self.name}")
try:
......@@ -98,7 +123,7 @@ class Container:
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
def __deepcopy__(self, memo: Dict[int, Any]) -> "Container":
data_was_in_memo = id(self.storage[0]) in memo
r = type(self)(
deepcopy(self.type, memo=memo),
......@@ -132,41 +157,39 @@ class Linker:
allow_gc : optional, bool
Configures if garbage collection is enabled.
scheduler : callable
A scheduling function that takes a FunctionGraph and returns a list of Apply nodes.
Defaults to the .toposort() method of the FunctionGraph.
A scheduling function that takes a FunctionGraph and returns
a list of Apply nodes. Defaults to the .toposort() method of
the FunctionGraph.
"""
def __init__(
self,
*,
allow_gc: typing.Optional[bool] = None,
scheduler: typing.Callable[[FunctionGraph], typing.List[Apply]] = None,
):
allow_gc: Optional[bool] = None,
scheduler: Callable[[FunctionGraph], List[Apply]] = None,
) -> None:
self._allow_gc = allow_gc
self._scheduler = scheduler
super().__init__()
@property
def allow_gc(self) -> typing.Optional[bool]:
def allow_gc(self) -> Optional[bool]:
"""Determines if the linker may allow garbage collection.
None means undefined.
Returns
-------
_allow_gc : optional, bool
None means undefined.
"""
return self._allow_gc
def clone(self, allow_gc: typing.Optional[bool] = None):
def clone(self, allow_gc: Optional[bool] = None) -> "Linker":
new = copy(self)
if allow_gc is not None:
new._allow_gc = allow_gc
return new
def make_thunk(
self,
) -> typing.Tuple[
typing.Callable[[], typing.NoReturn],
typing.List[Container],
typing.List[Container],
]:
def make_thunk(self, **kwargs) -> ThunkType:
"""
This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the
......@@ -193,7 +216,7 @@ class Linker:
)
@deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single=True, **kwargs):
def make_function(self, unpack_single: bool = True, **kwargs) -> Callable:
"""
Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the
......@@ -201,6 +224,13 @@ class Linker:
operate in the same storage the fgraph uses, else independent storage
will be allocated for the function.
Parameters
----------
unpack_single : bool
If `unpack_single` is True (default) and that the function has only one
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
Examples
--------
e = x + y
......@@ -209,10 +239,6 @@ class Linker:
print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
If unpack_single is True (default) and that the function has only one
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
"""
thunk, inputs, outputs = self.make_thunk(**kwargs)
......@@ -231,23 +257,19 @@ class Linker:
else:
return [variable.data for variable in outputs]
execute.thunk = thunk
execute.inputs = inputs
execute.outputs = outputs
return execute
def schedule(self, fgraph: FunctionGraph) -> typing.List[Apply]:
def schedule(self, fgraph: FunctionGraph) -> List[Apply]:
"""Runs the scheduler (if set) or the toposort on the FunctionGraph.
Parameters
----------
fgraph : FunctionGraph
fgraph : :py:class:`aerasa.graph.fg.FunctionGraph`
A graph to compute the schedule for.
Returns
-------
nodes : list of Apply nodes
nodes : list of :py:class:`aesara.graph.basic.Apply` nodes
The result of the scheduling or toposort operation.
"""
if callable(self._scheduler):
......@@ -262,14 +284,20 @@ class LocalLinker(Linker):
"""
def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
def make_thunk(
self,
input_storage: Optional[Any] = None,
output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None,
**kwargs,
) -> ThunkType:
return self.make_all(
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)[:3]
def make_all(self, input_storage, output_storage):
def make_all(self, input_storage, output_storage, storage_map):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
......@@ -290,19 +318,27 @@ class PerformLinker(LocalLinker):
"""
def __init__(self, allow_gc=None, schedule=None):
def __init__(self, allow_gc: Optional[bool] = None, schedule: None = None) -> None:
if allow_gc is None:
allow_gc = config.allow_gc
self.fgraph = None
self.fgraph: Optional[FunctionGraph] = None
super().__init__(allow_gc=allow_gc, scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
def accept(
self,
fgraph: FunctionGraph,
no_recycling: Optional[
Union[List, Set["TensorVariable"], Set[Union[Variable, "TensorVariable"]]]
] = None,
profile: Optional[bool] = None,
) -> "PerformLinker":
"""
Parameters
----------
fgraph
A PerformLinker can have accepted one FunctionGraph instance at a time.
fgraph : :py:class:`aesara.graph.fg.FunctionGraph` instance
A :py:class:`aesara.link.basic.PerformLinker` instance can have accepted
one :py:class:`aesara.graph.fg.FunctionGraph` instance at a time.
no_recycling
WRITEME
......@@ -323,16 +359,25 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling
return self
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
def make_all(
self,
input_storage: Optional[Any] = None,
output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None,
) -> Union[
Tuple[Callable, List[Container], List[Container], List[Callable], List[Apply]],
Tuple[Callable, List[Container], List[Container], List[Any], List[Any]],
]:
"""
Returns Function to run all nodes, list of input containers, list of outputs
Parameters
----------
input_storage
input_storage : optional, list
list of storages corresponding to fgraph.inputs
output_storage
list of storages corresponding to fgraph.outputs
storage_map : iterable
Returns
-------
......@@ -342,7 +387,7 @@ class PerformLinker(LocalLinker):
(for all programs).
"""
fgraph = self.fgraph
fgraph: Any = self.fgraph
order = self.schedule(fgraph)
no_recycling = self.no_recycling
......@@ -367,10 +412,7 @@ class PerformLinker(LocalLinker):
thunks[-1].outputs = [storage_map[v] for v in node.outputs]
computed, last_user = gc_helper(order)
if self.allow_gc:
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
post_thunk_old_storage: Any = [] if self.allow_gc else None
for node in order:
if self.allow_gc:
......@@ -452,12 +494,16 @@ class WrapLinker(Linker):
"""
def __init__(self, linkers, wrapper):
def __init__(
self,
linkers: Union[List["OpWiseCLinker"], List[PerformLinker], List["VMLinker"]],
wrapper: Callable,
) -> None:
self.fgraph = None
self.linkers = linkers
self.wrapper = wrapper
def __copy__(self):
def __copy__(self) -> "WrapLinker":
"""
Shallow copy of a WrapLinker.
......@@ -483,12 +529,17 @@ class WrapLinker(Linker):
wrapper=self.wrapper,
)
def accept(self, fgraph, no_recycling=None, profile=None):
def accept(
self,
fgraph: FunctionGraph,
no_recycling: Optional[Union[Set["TensorVariable"], List]] = None,
profile: None = None,
) -> "WrapLinker":
"""
Parameters
----------
fgraph : FunctionGraph
fgraph : :py:class:`aesara.graph.fg.FunctionGraph`
The fgraph which we will link.
no_recycling : a list of Variables that belong to fgraph.
If a Variable is in no_recycling, L{WrapLinker} will clear
......@@ -506,10 +557,16 @@ class WrapLinker(Linker):
self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers]
return self
def pre(self, f, inputs, order, thunk_groups):
def pre(
self,
f: "WrapLinker",
inputs: Union[List[ndarray], List[Optional[float]]],
order: List[Apply],
thunk_groups: List[Tuple[Callable]],
) -> None:
pass
def make_thunk(self, **kwargs):
def make_thunk(self, **kwargs) -> ThunkType:
no_recycling = self.no_recycling
make_all = [self.linkers[0].make_all(**kwargs)]
......@@ -559,7 +616,9 @@ class WrapLinker(Linker):
return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers):
def WrapLinkerMany(
linkers: Union[List["OpWiseCLinker"], List["VMLinker"]], wrappers: List[Callable]
) -> WrapLinker:
"""
Variant on WrapLinker that runs a series of wrapper functions instead of
just one.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论