提交 5335e729 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add type annotations to ScanArgs methods

上级 711843bb
......@@ -5,6 +5,7 @@ import dataclasses
import logging
import warnings
from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Callable, List, Optional, Set, Tuple, Union
import numpy as np
......@@ -27,6 +28,9 @@ from aesara.tensor.subtensor import set_subtensor
from aesara.tensor.var import TensorConstant
if TYPE_CHECKING:
from aesara.scan.op import ScanInfo
_logger = logging.getLogger("aesara.scan.utils")
......@@ -997,7 +1001,7 @@ class ScanArgs:
assert q == len(inner_outputs)
@staticmethod
def from_node(node, clone=False):
def from_node(node, clone=False) -> "ScanArgs":
from aesara.scan.op import Scan
if not isinstance(node.op, Scan):
......@@ -1013,7 +1017,7 @@ class ScanArgs:
)
@classmethod
def create_empty(cls):
def create_empty(cls) -> "ScanArgs":
from aesara.scan.op import ScanInfo
info = ScanInfo(
......@@ -1114,7 +1118,7 @@ class ScanArgs:
)
@property
def info(self):
def info(self) -> "ScanInfo":
from aesara.scan.op import ScanInfo
return ScanInfo(
......@@ -1133,22 +1137,27 @@ class ScanArgs:
mit_mot_out_slices=tuple(self.mit_mot_out_slices),
)
def get_alt_field(self, var_info, alt_prefix):
def get_alt_field(
self, var_info: Union[Variable, FieldInfo], alt_prefix: str
) -> Variable:
"""Get the alternate input/output field for a given element of `ScanArgs`.
For example, if `var_info` is in `ScanArgs.outer_out_sit_sot`, then
`get_alt_field(var_info, "inner_out")` returns the element corresponding
`var_info` in `ScanArgs.inner_out_sit_sot`.
For example, if `var_info` is in ``ScanArgs.outer_out_sit_sot``, then
``get_alt_field(var_info, "inner_out")`` returns the element corresponding
`var_info` in ``ScanArgs.inner_out_sit_sot``.
Parameters
----------
var_info: TensorVariable or FieldInfo
var_info:
The element for which we want the alternate
alt_prefix: str
alt_prefix:
The string prefix for the alternate field type. It can be one of
the following: "inner_out", "inner_in", "outer_in", and "outer_out".
the following: ``"inner_out"``, ``"inner_in"``, ``"outer_in"``, and
``"outer_out"``.
Outputs
-------
TensorVariable
Returns the alternate variable.
The alternate variable.
"""
if not isinstance(var_info, FieldInfo):
var_info = self.find_among_fields(var_info)
......@@ -1157,17 +1166,22 @@ class ScanArgs:
alt_var = getattr(self, "inner_out_{}".format(alt_type))[var_info.index]
return alt_var
def find_among_fields(self, i, field_filter=default_filter):
def find_among_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter
) -> Optional[Tuple[str, int, int, int]]:
"""Find the type and indices of the field containing a given element.
NOTE: This only returns the *first* field containing the given element.
Parameters
----------
i: theano.gof.graph.Variable
i:
The element to find among this object's fields.
field_filter: function
field_filter:
A function passed to `filter` that determines which fields to
consider. It must take a string field name and return a truthy
value.
Returns
-------
A tuple of length 4 containing the field name string, the first index,
......@@ -1204,7 +1218,9 @@ class ScanArgs:
return None
def _remove_from_fields(self, i, field_filter=default_filter):
def _remove_from_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter
) -> Optional[FieldInfo]:
field_info = self.find_among_fields(i, field_filter=field_filter)
......@@ -1218,7 +1234,9 @@ class ScanArgs:
return field_info
def get_dependent_nodes(self, i, seen=None):
def get_dependent_nodes(
self, i: Variable, seen: Optional[Set[int]] = None
) -> Set[Variable]:
if seen is None:
seen = {i}
else:
......@@ -1290,7 +1308,9 @@ class ScanArgs:
return dependent_nodes
def remove_from_fields(self, i, rm_dependents=True):
def remove_from_fields(
self, i: Variable, rm_dependents: bool = True
) -> List[FieldInfo]:
if rm_dependents:
vars_to_remove = self.get_dependent_nodes(i) | {i}
......@@ -1324,7 +1344,7 @@ class ScanArgs:
setattr(res, attr, copy.copy(getattr(self, attr)))
return res
def merge(self, other):
def merge(self, other: "ScanArgs") -> "ScanArgs":
res = copy.copy(self)
for attr in self.__dict__:
if (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论