提交 48ea2d4b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing Scan docstrings describing input and output layouts

上级 79112666
......@@ -204,6 +204,10 @@ def copy_var_format(var, as_var):
@dataclasses.dataclass(frozen=True)
class ScanInfo:
tap_array: tuple
"""
This is a tuple containing tuples of inner-output lag/lead values for the
mit-mots, mit-sots, and ``[-1]`` for each sit-sot.
"""
n_seqs: int
n_mit_mot: int
n_mit_mot_outs: int
......@@ -573,6 +577,36 @@ class ScanMethodsMixin:
class Scan(Op, ScanMethodsMixin, HasInnerGraph):
r"""An `Op` implementing `for` and `while` loops.
This `Op` has an "inner-graph" that represents the steps performed during
each iteration (or in the body of its loops). The vernacular for `Scan`
uses the prefix "inner-" for things pertaining to the
aforementioned inner-graph and "outer-" for node/`Apply`-level
things. There are inputs and outputs for both, and the "outer-graph" is
the graph in which the `Scan`-using `Apply` node exists.
The term "tap" refers to a relationship between the inputs and outputs of
an inner-graph. There are four types of taps and they characterize all the
supported "connection" patterns between inner-graph inputs and outputs:
- nit-sot (i.e. no inputs and a single output): A nit-sot is an output
variable of the inner-graph that is not fed back as an input to the
next iteration of the inner function.
- sit-sot (i.e. a single input and a single output): A sit-sot is an output
variable of the inner-graph that is fed back as an input to the next
iteration of the inner-graph.
- mit-sot (i.e. multiple inputs and a single output): A mit-sot is an
output variable of the inner-graph that is fed back as an input to
future iterations of the inner function (either multiple future
iterations or a single one that isn't the immediate next one).
- mit-mot (i.e. multiple inputs and multiple outputs): TODO
"""
def __init__(
self,
inputs: List[Variable],
......@@ -594,8 +628,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
----------
inputs
Inputs of the inner function of `Scan`.
These take the following general form:
sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + shared-inputs + non-sequences
where each term is a list of `Variable`\s.
outputs
Outputs of the inner function of `Scan`.
These take the following general form:
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs [+ while-condition]
where each term is a list of `Variable`\s.
info
A collection of information about the sequences and taps.
mode
......@@ -742,6 +788,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.seqs_arg_offset + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
)
self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs
# XXX: This doesn't include `self.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
......@@ -801,15 +849,38 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def make_node(self, *inputs):
"""
Conventions:
inner_X - the variable corresponding to X in the inner function
of scan (the lambda function executed at every time
step)
outer_X - the variable corresponding to X in the outer graph,
i.e. the main graph (where the scan op lives)
inner_X_out - the variable representing the new value of X after
executing one step of scan (i.e. outputs given by
the inner function)
The `inputs` to this method take the following form:
sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs +
nit-sots +
non-sequences
Note that some ``non-sequences`` can also be shared variables, and that
``nit-sots`` variables are the lengths of each nit-sot output, because
nit-sots have no input connections (by definition). Also, don't forget
that mit-[s|m]ots each have a distinct number of inputs and/or outputs.
The (outer-)inputs in the :class:`Apply` nodes created by this method
take the following concatenative form:
[n_steps] +
sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs +
nit-sots +
non-sequences
The (outer-)outputs take the following form:
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs +
nit-sots +
shared-outputs
These outer-outputs essentially follow the same form as their
corresponding inner-outputs, excluding the final "while" condition
term.
"""
if not all(isinstance(i, Variable) for i in inputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论