提交 5335e729 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add type annotations to ScanArgs methods

上级 711843bb
...@@ -5,6 +5,7 @@ import dataclasses ...@@ -5,6 +5,7 @@ import dataclasses
import logging import logging
import warnings import warnings
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Callable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
...@@ -27,6 +28,9 @@ from aesara.tensor.subtensor import set_subtensor ...@@ -27,6 +28,9 @@ from aesara.tensor.subtensor import set_subtensor
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
if TYPE_CHECKING:
from aesara.scan.op import ScanInfo
_logger = logging.getLogger("aesara.scan.utils") _logger = logging.getLogger("aesara.scan.utils")
...@@ -997,7 +1001,7 @@ class ScanArgs: ...@@ -997,7 +1001,7 @@ class ScanArgs:
assert q == len(inner_outputs) assert q == len(inner_outputs)
@staticmethod @staticmethod
def from_node(node, clone=False): def from_node(node, clone=False) -> "ScanArgs":
from aesara.scan.op import Scan from aesara.scan.op import Scan
if not isinstance(node.op, Scan): if not isinstance(node.op, Scan):
...@@ -1013,7 +1017,7 @@ class ScanArgs: ...@@ -1013,7 +1017,7 @@ class ScanArgs:
) )
@classmethod @classmethod
def create_empty(cls): def create_empty(cls) -> "ScanArgs":
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
info = ScanInfo( info = ScanInfo(
...@@ -1114,7 +1118,7 @@ class ScanArgs: ...@@ -1114,7 +1118,7 @@ class ScanArgs:
) )
@property @property
def info(self): def info(self) -> "ScanInfo":
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
return ScanInfo( return ScanInfo(
...@@ -1133,22 +1137,27 @@ class ScanArgs: ...@@ -1133,22 +1137,27 @@ class ScanArgs:
mit_mot_out_slices=tuple(self.mit_mot_out_slices), mit_mot_out_slices=tuple(self.mit_mot_out_slices),
) )
def get_alt_field(self, var_info, alt_prefix): def get_alt_field(
self, var_info: Union[Variable, FieldInfo], alt_prefix: str
) -> Variable:
"""Get the alternate input/output field for a given element of `ScanArgs`. """Get the alternate input/output field for a given element of `ScanArgs`.
For example, if `var_info` is in `ScanArgs.outer_out_sit_sot`, then
`get_alt_field(var_info, "inner_out")` returns the element corresponding For example, if `var_info` is in ``ScanArgs.outer_out_sit_sot``, then
`var_info` in `ScanArgs.inner_out_sit_sot`. ``get_alt_field(var_info, "inner_out")`` returns the element corresponding
`var_info` in ``ScanArgs.inner_out_sit_sot``.
Parameters Parameters
---------- ----------
var_info: TensorVariable or FieldInfo var_info:
The element for which we want the alternate The element for which we want the alternate
alt_prefix: str alt_prefix:
The string prefix for the alternate field type. It can be one of The string prefix for the alternate field type. It can be one of
the following: "inner_out", "inner_in", "outer_in", and "outer_out". the following: ``"inner_out"``, ``"inner_in"``, ``"outer_in"``, and
``"outer_out"``.
Outputs Outputs
------- -------
TensorVariable The alternate variable.
Returns the alternate variable.
""" """
if not isinstance(var_info, FieldInfo): if not isinstance(var_info, FieldInfo):
var_info = self.find_among_fields(var_info) var_info = self.find_among_fields(var_info)
...@@ -1157,17 +1166,22 @@ class ScanArgs: ...@@ -1157,17 +1166,22 @@ class ScanArgs:
alt_var = getattr(self, "inner_out_{}".format(alt_type))[var_info.index] alt_var = getattr(self, "inner_out_{}".format(alt_type))[var_info.index]
return alt_var return alt_var
def find_among_fields(self, i, field_filter=default_filter): def find_among_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter
) -> Optional[Tuple[str, int, int, int]]:
"""Find the type and indices of the field containing a given element. """Find the type and indices of the field containing a given element.
NOTE: This only returns the *first* field containing the given element. NOTE: This only returns the *first* field containing the given element.
Parameters Parameters
---------- ----------
i: theano.gof.graph.Variable i:
The element to find among this object's fields. The element to find among this object's fields.
field_filter: function field_filter:
A function passed to `filter` that determines which fields to A function passed to `filter` that determines which fields to
consider. It must take a string field name and return a truthy consider. It must take a string field name and return a truthy
value. value.
Returns Returns
------- -------
A tuple of length 4 containing the field name string, the first index, A tuple of length 4 containing the field name string, the first index,
...@@ -1204,7 +1218,9 @@ class ScanArgs: ...@@ -1204,7 +1218,9 @@ class ScanArgs:
return None return None
def _remove_from_fields(self, i, field_filter=default_filter): def _remove_from_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter
) -> Optional[FieldInfo]:
field_info = self.find_among_fields(i, field_filter=field_filter) field_info = self.find_among_fields(i, field_filter=field_filter)
...@@ -1218,7 +1234,9 @@ class ScanArgs: ...@@ -1218,7 +1234,9 @@ class ScanArgs:
return field_info return field_info
def get_dependent_nodes(self, i, seen=None): def get_dependent_nodes(
self, i: Variable, seen: Optional[Set[int]] = None
) -> Set[Variable]:
if seen is None: if seen is None:
seen = {i} seen = {i}
else: else:
...@@ -1290,7 +1308,9 @@ class ScanArgs: ...@@ -1290,7 +1308,9 @@ class ScanArgs:
return dependent_nodes return dependent_nodes
def remove_from_fields(self, i, rm_dependents=True): def remove_from_fields(
self, i: Variable, rm_dependents: bool = True
) -> List[FieldInfo]:
if rm_dependents: if rm_dependents:
vars_to_remove = self.get_dependent_nodes(i) | {i} vars_to_remove = self.get_dependent_nodes(i) | {i}
...@@ -1324,7 +1344,7 @@ class ScanArgs: ...@@ -1324,7 +1344,7 @@ class ScanArgs:
setattr(res, attr, copy.copy(getattr(self, attr))) setattr(res, attr, copy.copy(getattr(self, attr)))
return res return res
def merge(self, other): def merge(self, other: "ScanArgs") -> "ScanArgs":
res = copy.copy(self) res = copy.copy(self)
for attr in self.__dict__: for attr in self.__dict__:
if ( if (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论