提交 bcd52e63 authored 作者: elc45's avatar elc45 提交者: Ricardo Vieira

make scan dosctring render legibly

上级 83d029c2
......@@ -164,7 +164,7 @@ def _manage_output_api_change(outputs, updates, return_updates):
def scan(
fn,
fn: typing.Callable,
sequences=None,
outputs_info=None,
non_sequences=None,
......@@ -179,19 +179,18 @@ def scan(
return_list=False,
return_updates: bool = True,
):
r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
r"""This function constructs and applies a scan operation to the provided arguments.
Parameters
----------
fn
`fn` is a function that describes the operations involved in one
fn : callable
A function that describes the operations involved in one
step of `scan`. `fn` should construct variables describing the
output of one iteration step. It should expect as input
`Variable`\s representing all the slices of the input sequences
and previous values of the outputs, as well as all other arguments
given to scan as `non_sequences`. The order in which scan passes
these variables to `fn` is the following :
* all time slices of the first sequence
* all time slices of the second sequence
* ...
......@@ -283,13 +282,11 @@ def scan(
Note that a number of steps--considered in here as the maximum
number of steps--is still required even though a condition is
passed. It is used to allocate memory if needed.
sequences
`sequences` is the list of `Variable`\s or ``dict``\s
describing the sequences `scan` has to iterate over. If a
sequence is given as wrapped in a ``dict``, then a set of optional
information can be provided about the sequence. The ``dict``
should have the following keys:
sequences : list of Variable or dict or None, optional
The sequences `scan` has to iterate over. If wrapped in
a ``dict``, then a set of optional information can be
provided about the sequence. The ``dict``should have the
following keys:
* ``input`` (*mandatory*) -- `Variable` representing the
sequence.
......@@ -301,14 +298,11 @@ def scan(
All `Variable`\s in the list `sequences` are automatically
wrapped into a ``dict`` where ``taps`` is set to ``[0]``
outputs_info
`outputs_info` is the list of `Variable`\s or ``dict``\s
describing the initial state of the outputs computed
recurrently. When the initial states are given as ``dict``\s,
optional information can be provided about the output corresponding
to those initial states. The ``dict`` should have the following
keys:
outputs_info : list of Variable or dict or None, optional
The initial state of the outputs computed recurrently.
If given as ``dict``\s, optional information can be
provided about the output corresponding to those initial states.
The ``dict`` should have the following keys:
* ``initial`` -- A `Variable` that represents the initial
state of a given output. In case the output is not computed
......@@ -359,59 +353,50 @@ def scan(
provided just for a subset of the outputs, an exception is
raised, because there is no convention on how scan should map
the provided information to the outputs of `fn`.
non_sequences
`non_sequences` is the list of arguments that are passed to
`fn` at each steps. One can choose to exclude variables
used in `fn` from this list, as long as they are part of the
computational graph, although--for clarity--this is *not* encouraged.
n_steps
`n_steps` is the number of steps to iterate given as an ``int``
non_sequences : list of Variable or None, optional
The arguments that are passed to `fn` at each step.
One can choose to exclude variables used in `fn` from this list,
as long as they are part of the computational graph, although
this is not encouraged for clarity.
n_steps : int or Variable or None, optional
The number of steps to iterate given as an ``int``
or a scalar `Variable`. If any of the input sequences do not have
enough elements, `scan` will raise an error. If the value is ``0``, the
outputs will have ``0`` rows. If `n_steps` is not provided, `scan` will
outputs will have ``0`` rows. If not provided, `scan` will
figure out the amount of steps it should run given its input
sequences. ``n_steps < 0`` is not supported anymore.
truncate_gradient
`truncate_gradient` is the number of steps to use in truncated
sequences. ``n_steps < 0`` is not supported.
truncate_gradient : int
The number of steps to use in truncated
back-propagation through time (BPTT). If you compute gradients through
a `Scan` `Op`, they are computed using BPTT. By providing a different
value then ``-1``, you choose to use truncated BPTT instead of classical
BPTT, where you go for only `truncate_gradient` number of steps back in
time.
go_backwards
`go_backwards` is a flag indicating if `scan` should go
go_backwards : bool
Indicates if `scan` should go
backwards through the sequences. If you think of each sequence
as indexed by time, making this flag ``True`` would mean that
`scan` goes back in time, namely that for any sequence it
starts from the end and goes towards ``0``.
name
name : str or None, optional
When profiling `scan`, it is helpful to provide a name for any
instance of `scan`.
For example, the profiler will produce an overall profile of your code
as well as profiles for the computation of one step of each instance of
`Scan`. The `name` of the instance appears in those profiles and can
greatly help to disambiguate information.
mode
instance of `scan`. For example, the profiler will produce an
overall profile of your code as well as profiles for the computation
of one step of each instance of `Scan`. The `name` of the instance
appears in those profiles and can greatly help to disambiguate information.
mode : str or None, optional
The mode used to compile the inner-graph.
If you prefer the computations of one step of `scan` to be done
differently then the entire function, you can use this parameter to
describe how the computations in this loop are done (see
`pytensor.function` for details about possible values and their meaning).
profile
profile : bool or str
If ``True`` or a non-empty string, a profile object will be created and
attached to the inner graph of `Scan`. When `profile` is ``True``, the
profiler results will use the name of the `Scan` instance, otherwise it
will use the passed string. The profiler only collects and prints
information when running the inner graph with the `CVM` `Linker`.
allow_gc
allow_gc : bool or None, optional
Set the value of `allow_gc` for the internal graph of the `Scan`. If
set to ``None``, this will use the value of
`pytensor.config.scan__allow_gc`.
......@@ -425,24 +410,21 @@ def scan(
speed up allocation of the subsequent iterations. All those temporary
allocations are freed at the end of all iterations; this is what the
flag `pytensor.config.allow_gc` means.
strict
strict : bool
If ``True``, all the shared variables used in `fn` must be provided as a
part of `non_sequences` or `sequences`.
return_list
return_list : bool
If ``True``, will always return a ``list``, even if there is only one output.
return_updates : bool, optional
If ``True`` (default), the returned tuple includes the updates dictionary.
Returns
-------
tuple
``tuple`` of the form ``(outputs, updates)``.
``outputs`` is either a `Variable` or a ``list`` of `Variable`\s
representing the outputs in the same order as in `outputs_info`.
``updates`` is a subclass of ``dict`` specifying the update rules for
all shared variables used in `Scan`.
This ``dict`` should be passed to `pytensor.function` when you compile
your function.
outputs : Variable or list of Variable
The outputs of the scan, in the same order as `outputs_info`.
updates : dict
Dictionary of update rules for shared variables used in the scan.
Pass this to `pytensor.function` when compiling your function.
"""
# General observation : this code is executed only once, at creation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论