提交 b7cf7933 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/scan_module/scan_utils.py

上级 deabd346
"""
This module provides utility functions for the Scan Op
This module provides utility functions for the Scan Op.
See scan.py for details on scan.
See scan.py for details on scan
"""
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
......@@ -43,6 +44,7 @@ def safe_new(x, tag='', dtype=None):
by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original
graph and the newly constructed graph.
"""
if hasattr(x, 'name') and x.name is not None:
nw_name = x.name + tag
......@@ -117,21 +119,28 @@ class until(object):
between the condition and the list of outputs ( unless we enforce and
order, but since this was not impose up to know it can make quite a bit
of code to fail).
"""
def __init__(self, condition):
self.condition = tensor.as_tensor_variable(condition)
assert self.condition.ndim == 0
def traverse(out, x, x_copy, d, visited=None):
''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options :
"""
Function used by scan to parse the tree and figure out which nodes
it needs to replace.
There are two options :
1) x and x_copy or on host, then you would replace x with x_copy
2) x is on gpu, x_copy on host, then you need to replace
host_from_gpu(x) with x_copy
This happens because initially shared variables are on GPU .. which is
This happens because initially shared variables are on GPU... which is
fine for the main computational graph but confuses things a bit for the
inner graph of scan '''
inner graph of scan.
"""
# ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs.
# if a ``visited`` set is given, it will be updated in-place so the callee
......@@ -191,25 +200,25 @@ def clone(output,
share_inputs=True,
copy_inputs=DEPRECATED_ARG):
"""
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
Function that allows replacing subgraphs of a computational graph.
It returns a copy of the initial subgraph with the corresponding
substitutions.
:type output: Theano Variables (or Theano expressions)
:param outputs: Theano expression that represents the computational
graph
Parameters
----------
output : Theano Variables (or Theano expressions)
Theano expression that represents the computational graph.
replace : dict
Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
If True, use the same inputs (and shared variables) as the original
graph. If False, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
copy_inputs
Deprecated, use share_inputs.
:type replace: dict
:param replace: dictionary describing which subgraphs should be
replaced by what
:type share_inputs: bool
:param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they
will always have the same value.
:param copy_inputs: deprecated, use share_inputs.
"""
if copy_inputs is not DEPRECATED_ARG:
warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`')
......@@ -251,7 +260,7 @@ def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates OrderedDict, the
list of outputs and the stopping condition returned by the
lambda expression and arrange them in a predefined order
lambda expression and arrange them in a predefined order.
WRITEME: what is the type of ls? how is it formatted?
if it's not in the predefined order already, how does
......@@ -297,6 +306,7 @@ def get_updates_and_outputs(ls):
Return True iff `x` is made only of lists, tuples, dictionaries, Theano
variables or `theano.scan_module.until` objects.
"""
# Is `x` a container we can iterate on?
iter_on = None
......@@ -390,10 +400,11 @@ def isNaN_or_Inf_or_None(x):
def expand(tensor_var, size):
'''
"""
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor.
'''
"""
# Corner case that I might use in an optimization
if size == 0:
return tensor_var
......@@ -406,7 +417,7 @@ def expand(tensor_var, size):
def equal_computations(xs, ys, in_xs=None, in_ys=None):
'''Checks if Theano graphs represent the same computations.
"""Checks if Theano graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
......@@ -420,7 +431,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
`ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`.
'''
"""
assert len(xs) == len(ys)
if in_xs is None:
in_xs = []
......@@ -460,14 +471,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Explore the two graphs, in parallel, depth first, comparing the nodes
# along the way for equality.
def compare_nodes(nd_x, nd_y, common, different):
''' Compare two nodes to determine if they perform equal computation.
"""
Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal
computation.
NOTE : This function relies on the variable common to cache
results to be more efficient.
'''
"""
if nd_x.op != nd_y.op:
return False
......@@ -537,13 +550,14 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
def infer_shape(outs, inputs, input_shapes):
'''
Compute the shape of the outputs given the shape of the inputs
of a theano graph.
"""
Compute the shape of the outputs given the shape of the inputs of a theano
graph.
We do it this way to avoid compiling the inner function just to get
the shape. Changes to ShapeFeature could require changes in this function.
'''
"""
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
......@@ -560,10 +574,10 @@ def infer_shape(outs, inputs, input_shapes):
shape_feature.set_shape(inp, inp_shp)
def local_traverse(out):
'''
"""
Go back in the graph, from out, adding computable shapes to shape_of.
'''
"""
if out in shape_feature.shape_of:
# Its shape is already known
return
......@@ -589,14 +603,18 @@ def infer_shape(outs, inputs, input_shapes):
class Validator(object):
def __init__(self, valid=None, invalid=None, valid_equivalent=None):
'''
Check if variables can be expressed without using variables in invalid.
"""
Check if variables can be expressed without using variables in invalid.
Parameters
----------
valid_equivalent
Provides a dictionary mapping some invalid variables to valid ones that
can be used instead.
init_valid_equivalent provides a dictionary mapping some invalid
variables to valid ones that can be used instead.
'''
"""
def __init__(self, valid=None, invalid=None, valid_equivalent=None):
if valid is None:
valid = []
if invalid is None:
......@@ -616,13 +634,14 @@ class Validator(object):
self.invalid.update(list(valid_equivalent.keys()))
def check(self, out):
'''
"""
Go backwards in the graph, from out, and check if out is valid.
If out is a valid node, (out, True) is returned.
If out is not valid, but has an equivalent e, (e, False) is returned.
If out is not valid and has no equivalent, None is returned.
'''
"""
if out in self.valid:
return out, True
elif out in self.valid_equivalent:
......@@ -667,12 +686,13 @@ class Validator(object):
def scan_can_remove_outs(op, out_idxs):
'''
"""
Looks at all outputs defined by indices ``out_idxs`` and see whom can be
removed from the scan op without affecting the rest. Return two lists,
the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed.
'''
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in
out_idxs]
required_inputs = gof.graph.inputs(non_removable)
......@@ -706,7 +726,7 @@ def scan_can_remove_outs(op, out_idxs):
def compress_outs(op, not_required, inputs):
'''
"""
Helpful function that gets a Scan op, a list of indices indicating
which outputs are not required anymore and should be removed, and
a list of inputs to the apply node corresponding to the scan op and
......@@ -714,7 +734,8 @@ def compress_outs(op, not_required, inputs):
the indicated outputs are eliminated. Note that eliminating an output
means removing its inputs from the inner funciton and from the
node inputs, and changing the dictionary.
'''
"""
info = OrderedDict()
info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs']
......@@ -852,6 +873,7 @@ def compress_outs(op, not_required, inputs):
def find_up(l_node, f_node):
r"""
Goes up in the graph and returns True if a node in nodes is found.
"""
if isinstance(l_node, gof.Apply):
l_outs = l_node.outputs
......@@ -866,8 +888,9 @@ def reconstruct_graph(inputs, outputs, tag=None):
"""
Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with
new variables of the same type, and returns those ( in the same
new variables of the same type, and returns those (in the same
order as the original inputs).
"""
if tag is None:
tag = ''
......@@ -885,7 +908,11 @@ def reconstruct_graph(inputs, outputs, tag=None):
class scan_args(object):
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
"""
Parses the inputs and outputs of scan in an easy to manipulate format.
"""
def __init__(self, outer_inputs, outer_outputs,
_inner_inputs, _inner_outputs, info):
self.n_steps = outer_inputs[0]
......@@ -1070,17 +1097,22 @@ class scan_args(object):
def forced_replace(out, x, y):
"""
:param out: Theano Variable
:param x: Theano Variable
:param y: Theano Variable
This function checks all internal values of the graph that computes the
variable ``out`` for occurances of values identical with ``x``. If such
occurances are encountered then they are replaced with variable ``y``.
For example:
out := sigmoid(wu)*(1-sigmoid(wu))
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y)
Check all internal values of the graph that compute the variable ``out``
for occurrences of values identical with ``x``. If such occurrences are
encountered then they are replaced with variable ``y``.
Parameters
----------
out : Theano Variable
x : Theano Variable
y : Theano Variable
Examples
--------
out := sigmoid(wu)*(1-sigmoid(wu))
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y)
"""
if out is None:
return None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论