提交 48ea2d4b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing Scan docstrings describing input and output layouts

上级 79112666
...@@ -204,6 +204,10 @@ def copy_var_format(var, as_var): ...@@ -204,6 +204,10 @@ def copy_var_format(var, as_var):
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ScanInfo: class ScanInfo:
tap_array: tuple tap_array: tuple
"""
This is a tuple containing tuples of inner-output lag/lead values for the
mit-mots, mit-sots, and ``[-1]`` for each sit-sot.
"""
n_seqs: int n_seqs: int
n_mit_mot: int n_mit_mot: int
n_mit_mot_outs: int n_mit_mot_outs: int
...@@ -573,6 +577,36 @@ class ScanMethodsMixin: ...@@ -573,6 +577,36 @@ class ScanMethodsMixin:
class Scan(Op, ScanMethodsMixin, HasInnerGraph): class Scan(Op, ScanMethodsMixin, HasInnerGraph):
r"""An `Op` implementing `for` and `while` loops.
This `Op` has an "inner-graph" that represents the steps performed during
each iteration (or in the body of its loops). The vernacular for `Scan`
uses the prefix "inner-" for things pertaining to the
aforementioned inner-graph and "outer-" for node/`Apply`-level
things. There are inputs and outputs for both, and the "outer-graph" is
the graph in which the `Scan`-using `Apply` node exists.
The term "tap" refers to a relationship between the inputs and outputs of
an inner-graph. There are four types of taps and they characterize all the
supported "connection" patterns between inner-graph inputs and outputs:
- nit-sot (i.e. no inputs and a single output): A nit-sot is an output
variable of the inner-graph that is not fed back as an input to the
next iteration of the inner function.
- sit-sot (i.e. a single input and a single output): A sit-sot is an output
variable of the inner-graph that is fed back as an input to the next
iteration of the inner-graph.
- mit-sot (i.e. multiple inputs and a single output): A mit-sot is an
output variable of the inner-graph that is fed back as an input to
future iterations of the inner function (either multiple future
iterations or a single one that isn't the immediate next one).
- mit-mot (i.e. multiple inputs and multiple outputs): TODO
"""
def __init__( def __init__(
self, self,
inputs: List[Variable], inputs: List[Variable],
...@@ -594,8 +628,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -594,8 +628,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
---------- ----------
inputs inputs
Inputs of the inner function of `Scan`. Inputs of the inner function of `Scan`.
These take the following general form:
sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + shared-inputs + non-sequences
where each term is a list of `Variable`\s.
outputs outputs
Outputs of the inner function of `Scan`. Outputs of the inner function of `Scan`.
These take the following general form:
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs [+ while-condition]
where each term is a list of `Variable`\s.
info info
A collection of information about the sequences and taps. A collection of information about the sequences and taps.
mode mode
...@@ -742,6 +788,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -742,6 +788,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.seqs_arg_offset + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.seqs_arg_offset + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
) )
self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs
# XXX: This doesn't include `self.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
...@@ -801,15 +849,38 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -801,15 +849,38 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def make_node(self, *inputs): def make_node(self, *inputs):
""" """
Conventions: The `inputs` to this method take the following form:
inner_X - the variable corresponding to X in the inner function
of scan (the lambda function executed at every time sequences +
step) mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
outer_X - the variable corresponding to X in the outer graph, shared-inputs +
i.e. the main graph (where the scan op lives) nit-sots +
inner_X_out - the variable representing the new value of X after non-sequences
executing one step of scan (i.e. outputs given by
the inner function) Note that some ``non-sequences`` can also be shared variables, and that
``nit-sots`` variables are the lengths of each nit-sot output, because
nit-sots have no input connections (by definition). Also, don't forget
that mit-[s|m]ots each have a distinct number of inputs and/or outputs.
The (outer-)inputs in the :class:`Apply` nodes created by this method
take the following concatenative form:
[n_steps] +
sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs +
nit-sots +
non-sequences
The (outer-)outputs take the following form:
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs +
nit-sots +
shared-outputs
These outer-outputs essentially follow the same form as their
corresponding inner-outputs, excluding the final "while" condition
term.
""" """
if not all(isinstance(i, Variable) for i in inputs): if not all(isinstance(i, Variable) for i in inputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论