提交 bacd93af authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/sandbox/scan_module/scan_utils.py

上级 8d4e690a
"""
This module provides utility functions for the Scan Op
This module provides utility functions for the Scan Op.
See scan.py for details on scan.
See scan.py for details on scan
"""
from __future__ import print_function
__docformat__ = 'restructedtext en'
......@@ -41,8 +42,11 @@ def expand(tensor_var, size):
``tensor_var``, namely:
rval[:d1] = tensor_var
Parameters
----------
tensor_var : Theano tensor variable
size : int
"""
# Corner case that I might use in an optimization
if size == 0:
......@@ -57,7 +61,8 @@ def expand(tensor_var, size):
def to_list(ls):
"""
Converts ``ls`` to list if it is a tuple, or wraps ``ls`` into a list if
it is not a list already.
"""
if isinstance(ls, (list, tuple)):
return list(ls)
......@@ -70,7 +75,9 @@ class until(object):
Theano can end on a condition. In order to differentiate this condition
from the other outputs of scan, this class is used to wrap the condition
around it.
"""
def __init__(self, condition):
    # Convert the user-supplied termination condition into a Theano
    # variable (accepts plain Python/numpy scalars as well).
    self.condition = tensor.as_tensor_variable(condition)
    # The condition must be a scalar (0-d) value so scan can test it
    # once per step.  NOTE(review): `assert` disappears under `-O`;
    # consider raising ValueError instead — confirm with callers.
    assert self.condition.ndim == 0
......@@ -78,10 +85,12 @@ class until(object):
def get_updates_and_outputs(ls):
"""
Parses the list ``ls`` into outputs and updates. The semantics
of ``ls`` is defined by the constructive function of scan.
Parses the list ``ls`` into outputs and updates.
The semantics of ``ls`` is defined by the constructive function of scan.
The elemets of ``ls`` are either a list of expressions representing the
outputs/states, a dictionary of updates or a condition.
"""
def is_list_outputs(elem):
if (isinstance(elem, (list, tuple)) and
......@@ -150,23 +159,23 @@ def get_updates_and_outputs(ls):
def clone(output, replace=None, strict=True, share_inputs=True):
"""
Function that allows replacing subgraphs of a computational graph.
It returns a copy of the initial subgraph with the corresponding
substitutions.

Parameters
----------
output : Theano Variables (or Theano expressions)
    Theano expression that represents the computational graph.
replace : dict
    Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
    If True, use the same inputs (and shared variables) as the original
    graph. If False, clone them. Note that cloned shared variables still
    use the same underlying storage, so they will always have the same
    value.
"""
inps, outs, other_stuff = rebuild_collect_shared(output,
[],
......@@ -189,6 +198,7 @@ def canonical_arguments(sequences,
Mainly it makes sure that arguments are given as lists of dictionaries,
and that the different fields of a dictionary are set to default
value if the user has not provided any.
"""
states_info = to_list(outputs_info)
parameters = [tensor.as_tensor_variable(x) for x in to_list(non_sequences)]
......@@ -303,13 +313,14 @@ def canonical_arguments(sequences,
def infer_shape(outs, inputs, input_shapes):
"""
Compute the shape of the outputs given the shape of the inputs
of a theano graph.

We do it this way to avoid compiling the inner function just to get the
shape. Changes to ShapeFeature could require changes in this function.
"""
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
......@@ -326,9 +337,10 @@ def infer_shape(outs, inputs, input_shapes):
shape_feature.set_shape(inp, inp_shp)
def local_traverse(out):
"""
Go back in the graph, from out, adding computable shapes to shape_of.
"""
if out in shape_feature.shape_of:
# Its shape is already known
......@@ -358,14 +370,17 @@ def allocate_memory(T, y_info, y):
"""
Allocates memory for an output of scan.
Parameters
----------
T : scalar
    Variable representing the number of steps scan will run.
y_info : dict
    Dictionary describing the output (more specifically describing shape
    information for the output).
y : Tensor variable
    Expression describing the computation resulting in out entry of y.
    It can be used to infer the shape of y.
"""
if 'shape' in y_info:
return tensor.zeros([T, ] + list(y_info['shape']),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论