提交 b7cf7933 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/scan_module/scan_utils.py

上级 deabd346
""" """
This module provides utility functions for the Scan Op This module provides utility functions for the Scan Op.
See scan.py for details on scan.
See scan.py for details on scan
""" """
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu " __authors__ = ("Razvan Pascanu "
...@@ -43,6 +44,7 @@ def safe_new(x, tag='', dtype=None): ...@@ -43,6 +44,7 @@ def safe_new(x, tag='', dtype=None):
by gradient, or the R-op to construct new variables for the inputs of by gradient, or the R-op to construct new variables for the inputs of
the inner graph such that there is no interference between the original the inner graph such that there is no interference between the original
graph and the newly constructed graph. graph and the newly constructed graph.
""" """
if hasattr(x, 'name') and x.name is not None: if hasattr(x, 'name') and x.name is not None:
nw_name = x.name + tag nw_name = x.name + tag
...@@ -117,21 +119,28 @@ class until(object): ...@@ -117,21 +119,28 @@ class until(object):
between the condition and the list of outputs ( unless we enforce an between the condition and the list of outputs ( unless we enforce an
order, but since this was not imposed up to now it can cause quite a bit order, but since this was not imposed up to now it can cause quite a bit
of code to fail). of code to fail).
""" """
def __init__(self, condition): def __init__(self, condition):
self.condition = tensor.as_tensor_variable(condition) self.condition = tensor.as_tensor_variable(condition)
assert self.condition.ndim == 0 assert self.condition.ndim == 0
def traverse(out, x, x_copy, d, visited=None): def traverse(out, x, x_copy, d, visited=None):
''' Function used by scan to parse the tree and figure out which nodes """
it needs to replace. There are two options : Function used by scan to parse the tree and figure out which nodes
it needs to replace.
There are two options :
1) x and x_copy are on host, then you would replace x with x_copy 1) x and x_copy are on host, then you would replace x with x_copy
2) x is on gpu, x_copy on host, then you need to replace 2) x is on gpu, x_copy on host, then you need to replace
host_from_gpu(x) with x_copy host_from_gpu(x) with x_copy
This happens because initially shared variables are on GPU .. which is This happens because initially shared variables are on GPU... which is
fine for the main computational graph but confuses things a bit for the fine for the main computational graph but confuses things a bit for the
inner graph of scan ''' inner graph of scan.
"""
# ``visited`` is a set of nodes that are already known and don't need to be # ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs. # checked again, speeding up the traversal of multiply-connected graphs.
# if a ``visited`` set is given, it will be updated in-place so the callee # if a ``visited`` set is given, it will be updated in-place so the callee
...@@ -191,25 +200,25 @@ def clone(output, ...@@ -191,25 +200,25 @@ def clone(output,
share_inputs=True, share_inputs=True,
copy_inputs=DEPRECATED_ARG): copy_inputs=DEPRECATED_ARG):
""" """
Function that allows replacing subgraphs of a computational Function that allows replacing subgraphs of a computational graph.
graph. It returns a copy of the initial subgraph with the corresponding
It returns a copy of the initial subgraph with the corresponding
substitutions. substitutions.
:type output: Theano Variables (or Theano expressions) Parameters
:param outputs: Theano expression that represents the computational ----------
graph output : Theano Variables (or Theano expressions)
Theano expression that represents the computational graph.
replace : dict
Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
If True, use the same inputs (and shared variables) as the original
graph. If False, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
copy_inputs
Deprecated, use share_inputs.
:type replace: dict
:param replace: dictionary describing which subgraphs should be
replaced by what
:type share_inputs: bool
:param share_inputs: If True, use the same inputs (and shared variables)
as the original graph. If False, clone them. Note that cloned
shared variables still use the same underlying storage, so they
will always have the same value.
:param copy_inputs: deprecated, use share_inputs.
""" """
if copy_inputs is not DEPRECATED_ARG: if copy_inputs is not DEPRECATED_ARG:
warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`') warnings.warn('In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`')
...@@ -251,7 +260,7 @@ def get_updates_and_outputs(ls): ...@@ -251,7 +260,7 @@ def get_updates_and_outputs(ls):
""" """
This function tries to recognize the updates OrderedDict, the This function tries to recognize the updates OrderedDict, the
list of outputs and the stopping condition returned by the list of outputs and the stopping condition returned by the
lambda expression and arrange them in a predefined order lambda expression and arrange them in a predefined order.
WRITEME: what is the type of ls? how is it formatted? WRITEME: what is the type of ls? how is it formatted?
if it's not in the predefined order already, how does if it's not in the predefined order already, how does
...@@ -297,6 +306,7 @@ def get_updates_and_outputs(ls): ...@@ -297,6 +306,7 @@ def get_updates_and_outputs(ls):
Return True iff `x` is made only of lists, tuples, dictionaries, Theano Return True iff `x` is made only of lists, tuples, dictionaries, Theano
variables or `theano.scan_module.until` objects. variables or `theano.scan_module.until` objects.
""" """
# Is `x` a container we can iterate on? # Is `x` a container we can iterate on?
iter_on = None iter_on = None
...@@ -390,10 +400,11 @@ def isNaN_or_Inf_or_None(x): ...@@ -390,10 +400,11 @@ def isNaN_or_Inf_or_None(x):
def expand(tensor_var, size): def expand(tensor_var, size):
''' """
Transforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..) Transforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor. by adding 0s at the end of the tensor.
'''
"""
# Corner case that I might use in an optimization # Corner case that I might use in an optimization
if size == 0: if size == 0:
return tensor_var return tensor_var
...@@ -406,7 +417,7 @@ def expand(tensor_var, size): ...@@ -406,7 +417,7 @@ def expand(tensor_var, size):
def equal_computations(xs, ys, in_xs=None, in_ys=None): def equal_computations(xs, ys, in_xs=None, in_ys=None):
'''Checks if Theano graphs represent the same computations. """Checks if Theano graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The The two lists `xs`, `ys` should have the same number of entries. The
function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)` function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
...@@ -420,7 +431,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -420,7 +431,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
`ys`, but also represent subgraphs of a computational graph in `xs` `ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`. or `ys`.
''' """
assert len(xs) == len(ys) assert len(xs) == len(ys)
if in_xs is None: if in_xs is None:
in_xs = [] in_xs = []
...@@ -460,14 +471,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -460,14 +471,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Explore the two graphs, in parallel, depth first, comparing the nodes # Explore the two graphs, in parallel, depth first, comparing the nodes
# along the way for equality. # along the way for equality.
def compare_nodes(nd_x, nd_y, common, different): def compare_nodes(nd_x, nd_y, common, different):
''' Compare two nodes to determine if they perform equal computation. """
Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal by ensuring that the inputs themselves are the result of equal
computation. computation.
NOTE : This function relies on the variable common to cache NOTE : This function relies on the variable common to cache
results to be more efficient. results to be more efficient.
'''
"""
if nd_x.op != nd_y.op: if nd_x.op != nd_y.op:
return False return False
...@@ -537,13 +550,14 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -537,13 +550,14 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
''' """
Compute the shape of the outputs given the shape of the inputs Compute the shape of the outputs given the shape of the inputs of a theano
of a theano graph. graph.
We do it this way to avoid compiling the inner function just to get We do it this way to avoid compiling the inner function just to get
the shape. Changes to ShapeFeature could require changes in this function. the shape. Changes to ShapeFeature could require changes in this function.
'''
"""
# We use a ShapeFeature because it has all the necessary logic # We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we # inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will # let it initialize itself with an empty fgraph, otherwise we will
...@@ -560,10 +574,10 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -560,10 +574,10 @@ def infer_shape(outs, inputs, input_shapes):
shape_feature.set_shape(inp, inp_shp) shape_feature.set_shape(inp, inp_shp)
def local_traverse(out): def local_traverse(out):
''' """
Go back in the graph, from out, adding computable shapes to shape_of. Go back in the graph, from out, adding computable shapes to shape_of.
'''
"""
if out in shape_feature.shape_of: if out in shape_feature.shape_of:
# Its shape is already known # Its shape is already known
return return
...@@ -589,14 +603,18 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -589,14 +603,18 @@ def infer_shape(outs, inputs, input_shapes):
class Validator(object): class Validator(object):
def __init__(self, valid=None, invalid=None, valid_equivalent=None): """
''' Check if variables can be expressed without using variables in invalid.
Check if variables can be expressed without using variables in invalid.
Parameters
----------
valid_equivalent
Provides a dictionary mapping some invalid variables to valid ones that
can be used instead.
init_valid_equivalent provides a dictionary mapping some invalid """
variables to valid ones that can be used instead.
'''
def __init__(self, valid=None, invalid=None, valid_equivalent=None):
if valid is None: if valid is None:
valid = [] valid = []
if invalid is None: if invalid is None:
...@@ -616,13 +634,14 @@ class Validator(object): ...@@ -616,13 +634,14 @@ class Validator(object):
self.invalid.update(list(valid_equivalent.keys())) self.invalid.update(list(valid_equivalent.keys()))
def check(self, out): def check(self, out):
''' """
Go backwards in the graph, from out, and check if out is valid. Go backwards in the graph, from out, and check if out is valid.
If out is a valid node, (out, True) is returned. If out is a valid node, (out, True) is returned.
If out is not valid, but has an equivalent e, (e, False) is returned. If out is not valid, but has an equivalent e, (e, False) is returned.
If out is not valid and has no equivalent, None is returned. If out is not valid and has no equivalent, None is returned.
'''
"""
if out in self.valid: if out in self.valid:
return out, True return out, True
elif out in self.valid_equivalent: elif out in self.valid_equivalent:
...@@ -667,12 +686,13 @@ class Validator(object): ...@@ -667,12 +686,13 @@ class Validator(object):
def scan_can_remove_outs(op, out_idxs): def scan_can_remove_outs(op, out_idxs):
''' """
Looks at all outputs defined by indices ``out_idxs`` and see which can be Looks at all outputs defined by indices ``out_idxs`` and see which can be
removed from the scan op without affecting the rest. Return two lists, removed from the scan op without affecting the rest. Return two lists,
the first one with the indices of outs that can be removed, the second the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed. with the outputs that can not be removed.
'''
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in non_removable = [o for i, o in enumerate(op.outputs) if i not in
out_idxs] out_idxs]
required_inputs = gof.graph.inputs(non_removable) required_inputs = gof.graph.inputs(non_removable)
...@@ -706,7 +726,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -706,7 +726,7 @@ def scan_can_remove_outs(op, out_idxs):
def compress_outs(op, not_required, inputs): def compress_outs(op, not_required, inputs):
''' """
Helpful function that gets a Scan op, a list of indices indicating Helpful function that gets a Scan op, a list of indices indicating
which outputs are not required anymore and should be removed, and which outputs are not required anymore and should be removed, and
a list of inputs to the apply node corresponding to the scan op and a list of inputs to the apply node corresponding to the scan op and
...@@ -714,7 +734,8 @@ def compress_outs(op, not_required, inputs): ...@@ -714,7 +734,8 @@ def compress_outs(op, not_required, inputs):
the indicated outputs are eliminated. Note that eliminating an output the indicated outputs are eliminated. Note that eliminating an output
means removing its inputs from the inner function and from the means removing its inputs from the inner function and from the
node inputs, and changing the dictionary. node inputs, and changing the dictionary.
'''
"""
info = OrderedDict() info = OrderedDict()
info['tap_array'] = [] info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs'] info['n_seqs'] = op.info['n_seqs']
...@@ -852,6 +873,7 @@ def compress_outs(op, not_required, inputs): ...@@ -852,6 +873,7 @@ def compress_outs(op, not_required, inputs):
def find_up(l_node, f_node): def find_up(l_node, f_node):
r""" r"""
Goes up in the graph and returns True if a node in nodes is found. Goes up in the graph and returns True if a node in nodes is found.
""" """
if isinstance(l_node, gof.Apply): if isinstance(l_node, gof.Apply):
l_outs = l_node.outputs l_outs = l_node.outputs
...@@ -866,8 +888,9 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -866,8 +888,9 @@ def reconstruct_graph(inputs, outputs, tag=None):
""" """
Different interface to clone, that allows you to pass inputs. Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with Compared to clone, this method always replaces the inputs with
new variables of the same type, and returns those ( in the same new variables of the same type, and returns those (in the same
order as the original inputs). order as the original inputs).
""" """
if tag is None: if tag is None:
tag = '' tag = ''
...@@ -885,7 +908,11 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -885,7 +908,11 @@ def reconstruct_graph(inputs, outputs, tag=None):
class scan_args(object): class scan_args(object):
"""Parses the inputs and outputs of scan in an easy to manipulate format""" """
Parses the inputs and outputs of scan in an easy to manipulate format.
"""
def __init__(self, outer_inputs, outer_outputs, def __init__(self, outer_inputs, outer_outputs,
_inner_inputs, _inner_outputs, info): _inner_inputs, _inner_outputs, info):
self.n_steps = outer_inputs[0] self.n_steps = outer_inputs[0]
...@@ -1070,17 +1097,22 @@ class scan_args(object): ...@@ -1070,17 +1097,22 @@ class scan_args(object):
def forced_replace(out, x, y): def forced_replace(out, x, y):
""" """
:param out: Theano Variable Check all internal values of the graph that compute the variable ``out``
:param x: Theano Variable for occurrences of values identical with ``x``. If such occurrences are
:param y: Theano Variable encountered then they are replaced with variable ``y``.
This function checks all internal values of the graph that computes the Parameters
variable ``out`` for occurances of values identical with ``x``. If such ----------
occurances are encountered then they are replaced with variable ``y``. out : Theano Variable
For example: x : Theano Variable
out := sigmoid(wu)*(1-sigmoid(wu)) y : Theano Variable
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y) Examples
--------
out := sigmoid(wu)*(1-sigmoid(wu))
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y)
""" """
if out is None: if out is None:
return None return None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论