Commit 96676ed5 authored by Razvan Pascanu

[scan][doc][coding-style] re-arranged the documentation of scan parameters

Parent b15fadcc
@@ -268,164 +268,250 @@ def foldr( fn
# Yes, actually it will be exactly 2 ( if there are no other constraints)
def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
n_steps = None, truncate_gradient = -1, go_backwards = False,
mode = None, name = None):
"""Function that constructs and applies a Scan op
def scan( fn
, sequences = None
, outputs_info = None
, non_sequences = None
, n_steps = None
, truncate_gradient = -1
, go_backwards = False
, mode = None
, name = None ):
"""
This function constructs and applies a Scan op to the provided
arguments.
:param fn:
Function that describes the operations involved in one step of scan
Given variables representing all the slices of input and past values of
outputs and other non-sequence parameters, ``fn`` should produce
variables describing the output of one time step of scan. The order in
which the arguments to this function are given is very important. You
should have the following order:
* all time slices of the first sequence (as given in the
``sequences`` list) ordered in the same fashion as the time taps provided
* all time slices of the second sequence (as given in the
``sequences`` list) ordered in the same fashion as the time taps provided
``fn`` is a function that describes the operations involved in one step
of ``scan``. ``fn`` should construct variables describing the output of
one iteration step. It should expect as input theano variables
representing all the time slices of the input sequences and outputs,
and all other arguments given to scan as ``non_sequences``. The order
in which scan passes these variables to ``fn`` is the following:
* all time slices of the first sequence
* all time slices of the second sequence
* ...
* all time slices of the first output (as given in the
``initial_state`` list) ordered in the same fashion as the time taps provided
* all time slices of the second output (as given in the
``initial_state`` list) ordered in the same fashion as the time taps provided
* all time slices of the last sequence
* all time slices of the first output
* all time slices of the second output
* ...
* all other parameters over which scan doesn't iterate ordered accordingly
If you are using shared variables over which you do not want to iterate,
you do not need to provide them as arguments to ``fn``, though you can if you
wish. The function should return the outputs after each step plus the updates
for any of the shared variables. You can either return only outputs or only
updates. If you have both outputs and updates the function should return
them as a tuple: (outputs, updates) or (updates, outputs).
* all time slices of the last output
* all other arguments (the list given as `non_sequences` to
scan)
The order of the sequences is the same as the one in the list
``sequences`` given to scan. The order of the outputs is the same
as the order of ``outputs_info``. For any sequence or output the
order of the time slices is the same as the order of the time
taps provided. For example if one writes the following:
.. code-block:: python
scan(fn, sequences = [ dict( Sequence1, taps = [-3,2,-1])
, Sequence2
, dict( Sequence3, taps = 3) ]
, outputs_info = [ dict( Output1, taps = [-3,-5])
, dict( Output2, taps = None)
, Output3 ]
, non_sequences = [ Argument1, Argument2 ])
``fn`` should expect the following arguments in this given order:
#. ``Sequence1[t-3]``
#. ``Sequence1[t+2]``
#. ``Sequence1[t-1]``
#. ``Sequence2[t]``
#. ``Sequence3[t+3]``
#. ``Output1[t-3]``
#. ``Output1[t-5]``
#. ``Output3[t-1]``
#. ``Argument1``
#. ``Argument2``
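This ordering rule can be reproduced with a small pure-Python sketch (the helper `arg_order` is hypothetical, not part of Theano; it only mirrors the rule described above):

```python
# Hypothetical helper reproducing the ordering rule above (scan builds
# this ordering internally; the helper itself is purely illustrative).
def arg_order(seq_taps, out_taps, non_sequences):
    """Names of the slices ``fn`` receives, in order."""
    args = []
    for name, taps in seq_taps:                  # sequences first, in list order
        for k in taps:                           # taps in the order provided
            args.append('%s[t%+d]' % (name, k) if k else '%s[t]' % name)
    for name, taps in out_taps:                  # then outputs, in list order
        for k in (taps or []):                   # taps=None contributes nothing
            args.append('%s[t%+d]' % (name, k))
    return args + list(non_sequences)            # non_sequences come last

order = arg_order(
    [('Sequence1', [-3, 2, -1]), ('Sequence2', [0]), ('Sequence3', [3])],
    [('Output1', [-3, -5]), ('Output2', None), ('Output3', [-1])],
    ['Argument1', 'Argument2'])
```

Running this yields exactly the ten arguments enumerated above.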
The list of ``non_sequences`` can also contain shared variables
used in the function, though ``scan`` is able to figure those
out on its own so they can be skipped. For the clarity of the
code, though, we recommend providing them to scan.
The function is expected to return two things. One is a list of
outputs ordered in the same order as ``outputs_info``, with the
difference that there should be only one output variable per
output initial state (even if no tap value is used). Secondly
`fn` should return an update dictionary (that tells how to
update any shared variable after each iteration step). The
dictionary can optionally be given as a list of tuples. There is
no constraint on the order of these two lists; ``fn`` can return
either ``(outputs_list, update_dictionary)`` or ``(update_dictionary,
outputs_list)`` or just one of the two (in case the other is
empty).
Outputs can be just a theano expression if you have only one output or
a list of theano expressions. Updates can be given either as a list of tuples or
as a dictionary. If you have a list of outputs, the order of these
should match that of their ``initial_states``.
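This return-value convention can be sketched in plain Python. The helper below is hypothetical, not Theano's actual code; it only illustrates how an `(outputs, updates)` pair can be disambiguated by recognising updates as a dictionary or a list of pairs:

```python
# Hypothetical helper (not Theano's actual code) showing how the
# (outputs, updates) pair can be disambiguated whatever its order:
# updates are recognised as a dict, or a list of (variable, value) pairs.
def split_outputs_updates(rv):
    def is_updates(x):
        return isinstance(x, dict) or (
            isinstance(x, (list, tuple)) and
            all(isinstance(p, tuple) and len(p) == 2 for p in x))

    if isinstance(rv, tuple) and len(rv) == 2:
        first, second = rv
        if is_updates(first) and not is_updates(second):
            return second, dict(first)       # (updates, outputs)
        if is_updates(second):
            return first, dict(second)       # (outputs, updates)
    if is_updates(rv):
        return [], dict(rv)                  # only updates returned
    return rv, {}                            # only outputs returned
```

Either ordering, or just one of the two pieces, yields the same normalized result.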
:param sequences:
list of Theano variables or dictionaries containing Theano variables over which
scan needs to iterate. The reason you might want to wrap a certain Theano
variable in a dictionary is to provide auxiliary information about how to iterate
over that variable. For example this is how you specify that you want to use
several time slices of this sequence at each iteration step. The dictionary
should have the following keys :
* ``input`` -- Theano variable representing the sequence
* ``taps`` -- temporal taps to use for this sequence. They are given as a list
of ints, where a value ``k`` means that at iteration step ``t`` scan needs to
also provide the slice ``t+k``. The order in which you provide these int values
here is the same order in which the slices will be provided to ``fn``.
If you do not wrap a variable in a dictionary, scan will do it for you, under
the assumption that you use only one slice, defined as a tap of offset 0. This
means that at step ``t`` scan will provide the slice at position ``t``.
``sequences`` is the list of Theano variables or dictionaries
describing the sequences ``scan`` has to iterate over. If a
sequence is given wrapped in a dictionary, a set of optional
information can be provided about the sequence. The dictionary
should have the following keys:
* ``input`` (*mandatory*) -- Theano variable representing the
sequence.
* ``taps`` -- Temporal taps of the sequence required by ``fn``.
They are provided as a list of integers, where a value ``k`` implies
that at iteration step ``t`` scan will pass to ``fn`` the slice
``t+k``. Default value is ``[0]``.
Any Theano variable in the list ``sequences`` is automatically
wrapped into a dictionary where ``taps`` is set to ``[0]``.
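The wrapping rule can be sketched as follows (hypothetical helper; a bare entry in ``sequences`` behaves as if given as ``dict(input=var, taps=[0])``):

```python
# Hypothetical helper sketching the wrapping rule: a bare entry in
# ``sequences`` behaves as if given as dict(input=var, taps=[0]).
def normalize_sequences(sequences):
    normalized = []
    for seq in sequences:
        if not isinstance(seq, dict):
            seq = {'input': seq}
        seq = dict(seq)               # don't mutate the caller's dict
        seq.setdefault('taps', [0])   # default tap: the current slice
        normalized.append(seq)
    return normalized

norm = normalize_sequences(['x', {'input': 'y', 'taps': [-1, 0]}])
```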
:param outputs_info:
list of Theano variables or dictionaries containing Theano variables used
to initialize the outputs of scan. As before (for ``sequences``) the reason
you would wrap a Theano variable in a dictionary is to provide additional
information about how scan should deal with that specific output. The dictionary
should contain the following keys:
* ``initial`` -- Theano variable containing the initial state of the output
* ``taps`` -- temporal taps to use for this output. The taps are given as a
list of ints (only negative, since you can not use future values of outputs),
with the same meaning as for ``sequences`` (see above).
* ``inplace`` -- theano variable pointing to one of the input sequences; this
flag tells scan that the output should be computed in the memory space occupied
by that input sequence. Note that scan will only do this if allowed by the
rest of your computational graph and if you are not using past taps of the
input.
* ``return_steps`` -- how many steps to return from your output. If not given, or
0, scan will return all steps, otherwise it will return the last ``return_steps``.
Note that if you set this to something other than 0, scan will try to be smart
about the amount of memory it allocates for a given output.
If the function applied recursively uses only the
previous value of the output, the initial state should have the
same shape as one time step of the output; otherwise, the initial state
should have the same number of dimensions as the output. This is easily
understood through an example. For computing ``y[t]`` let us assume that we
need ``y[t-1]``, ``y[t-2]`` and ``y[t-4]``. Through an abuse of
notation, when ``t = 0``, we would need values for ``y[-1]``, ``y[-2]``
and ``y[-4]``. These values are provided by the initial state of ``y``,
which should have the same number of dimensions as ``y``, where the first
dimension should be large enough to cover all the required past values, which in
this case is 4. If ``init_y`` is the variable containing the initial state
of ``y``, then ``init_y[0]`` corresponds to ``y[-4]``, ``init_y[1]``
corresponds to ``y[-3]``, ``init_y[2]`` corresponds to ``y[-2]``, and
``init_y[3]`` corresponds to ``y[-1]``. The default behaviour of scan is
the following:
* if you do not wrap an output in a dictionary, scan will wrap it for you
assuming that you use only the last step of the output (i.e. it makes your tap
value list equal to [-1]) and that it is not computed inplace
* if you wrap an output in a dictionary and you do not provide any taps but
you provide an initial state, it will assume that you are using only a tap value
of -1
* if you wrap an output in a dictionary but you do not provide any initial state,
it assumes that you are not using any form of taps
* if you provide ``None`` instead of a variable or a dictionary, scan assumes
that you will not use any taps for this output (this would be the case for map)
If you did not provide any information for your outputs, scan will assume by
default that you are not using any taps for any of the outputs. If you provide
information for just a subset of outputs, scan will not know to which outputs
these correspond and will raise an error.
``outputs_info`` is the list of Theano variables or dictionaries
describing the initial state of the outputs computed
recurrently. When these initial states are given as dictionaries,
optional information can be provided about the output corresponding
to these initial states. The dictionary should have the following
keys:
* ``initial`` -- Theano variable that represents the initial
state of a given output. In case the output is not computed
recursively (think of a map) and does not require an initial
state, this field can be skipped. Given that only the previous
time step of the output is used by ``fn``, the initial state
should have the same shape as one time step of the output. If multiple time
taps are used, the initial state should have one extra
dimension that should cover all the possible taps. For example
if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0,
``fn`` will require (by an abuse of notation) ``output[-5]``,
``output[-2]`` and ``output[-1]``. These will be given by
the initial state, which in this case should have the shape
(5,)+output.shape. If this variable containing the initial
state is called ``init_y`` then ``init_y[0]`` *corresponds to*
``output[-5]``, ``init_y[1]`` *corresponds to* ``output[-4]``,
``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]``
corresponds to ``output[-2]``, and ``init_y[4]`` corresponds to
``output[-1]``. While this order might seem strange, it comes
naturally from splitting an array at a given point. Assume that
we have an array ``x``, and we choose ``k`` to be time step
``0``. Then our initial state would be ``x[:k]``, while the
output will be ``x[k:]``. Looking at this split, elements in
``x[:k]`` are ordered exactly like those in ``init_y``.
* ``taps`` -- Temporal taps of the output that will be passed to
``fn``. They are provided as a list of *negative* integers,
where a value ``k`` implies that at iteration step ``t`` scan will
pass to ``fn`` the slice ``t+k``.
* ``inplace`` -- One of the Theano variables provided as
``sequences``. ``scan`` will try to compute this output *in
place* of the provided input *iff* it respects the following
constraints:
* There is no other output that needs to be computed in
place of the same input.
* ``fn`` is not using past taps of the input sequence that
will get overwritten by the output.
* ``return_steps`` -- Integer representing the number of steps
to return for the current output. For example, if ``k`` is
provided, ``scan`` will return ``output[-k:]``. This is meant as a
hint: based on ``k`` and the past taps of the outputs used, scan
can be smart about the amount of memory it requires to store
intermediate results. If not given, or ``0``, ``scan`` will return
all computed steps.
* ``store_steps`` -- Integer representing the number of
intermediate steps ``scan`` should store for a given output. Use
this key only if you really know what you are doing. In general
it is recommended to let scan decide for you the amount of memory
it should use.
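The layout of the initial state can be illustrated with plain Python lists (no Theano involved; ``tap_at_step0`` is a hypothetical helper):

```python
# Plain-Python illustration (lists instead of Theano variables) of the
# initial-state layout for past taps such as [-5, -2, -1]; tap_at_step0
# is a hypothetical helper, not part of scan's API.
x = list(range(20))                  # the full history; x[5] is time step 0
k = 5
init_y, output = x[:k], x[k:]        # split at step 0: init_y == [0, 1, 2, 3, 4]

def tap_at_step0(init_y, j):
    """Value of tap j (negative, e.g. -5) available at step 0."""
    return init_y[len(init_y) + j]   # init_y[0] is the oldest value
```

Here ``init_y[0]`` plays the role of ``output[-5]`` and ``init_y[4]`` the role of ``output[-1]``, exactly as described above.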
``scan`` will follow this logic if partial information is given:
* If an output is not wrapped in a dictionary, ``scan`` will wrap
it in one assuming that you use only the last step of the output
(i.e. it makes your tap value list equal to [-1]) and that it is
not computed inplace.
* If you wrap an output in a dictionary and you do not provide any
taps but you provide an initial state it will assume that you are
using only a tap value of -1.
* If you wrap an output in a dictionary but you do not provide any
initial state, it assumes that you are not using any form of
taps.
* If you provide ``None`` instead of a variable or a dictionary,
``scan`` assumes that you will not use any taps for this output
(as would be the case, for example, for a map).
If ``outputs_info`` is an empty list or None, ``scan`` assumes
that no tap is used for any of the outputs. If information is
provided just for a subset of the outputs, an exception is
raised (because there is no convention on how scan should map
the provided information to the outputs of ``fn``).
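These defaulting rules can be sketched in plain Python (hypothetical helper; the actual logic lives inside scan):

```python
# Hypothetical helper sketching the defaulting rules above (the real
# logic lives inside scan):
def normalize_output_info(info):
    if info is None:                             # map-like output: no taps
        return {'taps': None}
    if not isinstance(info, dict):               # bare variable: initial state,
        return {'initial': info, 'taps': [-1]}   # last step only
    info = dict(info)                            # don't mutate the caller's dict
    if 'initial' in info and not info.get('taps'):
        info['taps'] = [-1]                      # initial given, taps omitted
    elif 'initial' not in info:
        info['taps'] = None                      # no initial state: no taps
    return info
```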
:param non_sequences:
Parameters over which scan should not iterate. These parameters are
given at each time step to the function applied recursively.
``non_sequences`` is the list of arguments that are passed to
``fn`` at each step. One can opt to exclude shared variables
used in ``fn`` from this list.
:param n_steps:
Number of steps to iterate. If the input sequences are not long enough, scan
will produce a warning and run only for the maximal amount of steps allowed by
the input sequences. If the value is 0, the outputs will have 0 rows. If the
value is negative, scan will run backwards (or if the flag go_backwards is
already set to true it will run forward in time). If n_steps is not provided,
or evaluates to None, inf or nan, scan will figure out the maximal number of
steps it can run given the input sequences and do that.
``n_steps`` is the number of steps to iterate, given as an int
or a Theano scalar. If any of the input sequences do not have
enough elements, scan will produce a warning and run only for
the maximal number of steps it can. If the *value is 0* the
outputs will have *0 rows*. If the value is negative, ``scan``
runs backwards in time. If the ``go_backwards`` flag is already
set and ``n_steps`` is also negative, ``scan`` will run forward
in time. If ``n_steps`` is not provided, or evaluates to ``None``,
``inf`` or ``NaN``, ``scan`` will figure out the number of
steps it should run given its input sequences.
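A pure-Python sketch of how the effective number of steps could be resolved under these rules (``resolve_n_steps`` is a hypothetical helper; scan performs this symbolically, and the clipping case emits a warning):

```python
import math

# Hypothetical helper sketching how the effective number of steps could
# be resolved; scan does this symbolically and warns when it must clip.
def resolve_n_steps(n_steps, seq_lengths):
    undefined = n_steps is None or (
        isinstance(n_steps, float) and
        (math.isinf(n_steps) or math.isnan(n_steps)))
    if undefined:
        if not seq_lengths:
            raise ValueError('Scan does not know for how many steps to iterate.')
        return min(seq_lengths)          # run as far as the sequences allow
    n = int(n_steps)
    if seq_lengths and abs(n) > min(seq_lengths):
        # not enough elements: clip to the maximal feasible number of
        # steps, keeping the sign (negative means running backwards)
        n = min(seq_lengths) if n > 0 else -min(seq_lengths)
    return n
```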
:param truncate_gradient:
Number of steps to use in truncated BPTT. If you compute gradients
through a scan op, they are computed using backpropagation through time.
By providing a value different from -1, you choose to use truncated BPTT
instead of classical BPTT, where you only do ``truncate_gradient``
number of steps.
``truncate_gradient`` is the number of steps to use in truncated
BPTT. If you compute gradients through a scan op, they are
computed using backpropagation through time. By providing a
value different from -1, you choose to use truncated BPTT instead
of classical BPTT, where you go only ``truncate_gradient``
steps back in time.
:param go_backwards:
Flag indicating if you should go backwards through the sequences (if you
think of the sequences as being indexed by time, this would mean going backwards
in time).
``go_backwards`` is a flag indicating if ``scan`` should go
backwards through the sequences. If you think of each sequence
as indexed by time, making this flag True would mean that
``scan`` goes back in time, namely that for any sequence it
starts from the end and goes towards 0.
:param name:
The name of the theano function compiled by the Scan op. It will show in the
profiler output.
When profiling ``scan`` it is crucial to provide a name for any
instance of ``scan``. The profiler will produce an overall
profile of your code as well as profiles for doing one iteration
step for each instance of ``scan``. The ``name`` of the instance is
how you differentiate between all these profiles.
:param mode:
The mode used when compiling the theano function in the Scan op.
If None, it will use the config mode. If None and the config mode is set to
profile mode, we will create a new instance of ProfileMode in order
to compute the timing correctly.
If no new instance is created, the time spent in Scan will show up twice in the
profiling, once as the time taken by scan, and a second time as the time
taken by the ops inside scan. This will be even worse for multiple cascading
scans.
The new profiler instance will be printed when python exits.
It is recommended to leave this argument as None, especially
when profiling ``scan`` (otherwise the results are not going to
be accurate). If you prefer the computations of one step of
``scan`` to be done differently than the entire function, set
this parameter (see ``theano.function`` for details about
possible values and their meaning).
:rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the
outputs of ``scan`` (in the same order as in
``outputs_info``). ``updates`` is a dictionary specifying the
update rules for all shared variables used in the scan
operation. This dictionary should be passed to ``theano.function``
when you compile your function.
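The overall call pattern can be mimicked with a Theano-free toy. This is a minimal sketch: ``py_scan`` is hypothetical and supports one sequence, one output with tap -1, and no shared-variable updates:

```python
# Theano-free toy mimicking the call pattern: py_scan is hypothetical
# and minimal (one sequence, one output with tap -1, no shared updates).
def py_scan(fn, sequence, outputs_info):
    out = outputs_info                # initial state, i.e. y[-1]
    outputs = []
    for x_t in sequence:              # one iteration per sequence element
        out = fn(x_t, out)
        outputs.append(out)
    return outputs, {}                # (outputs, updates)

# cumulative sum: y[t] = x[t] + y[t-1]
outputs, updates = py_scan(lambda x_t, y_tm1: x_t + y_tm1,
                           sequence=[1, 2, 3, 4], outputs_info=0)
```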
"""
# General observation : this code is executed only once, at creation
# of the computational graph, so we don't yet need to be smart about
# anything ( to speed things up)
# check if inputs are just single variables instead of lists
@@ -449,7 +535,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# and just apply the inner function once
# To do that we check here to see the nature of n_steps
if type(n_steps) in (float,int):
n_fixed_steps = int(n_steps)
else:
# also check if this value happens to be a constant,
# then we could do the same
@@ -460,16 +546,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# compute number of sequences and number of outputs
n_seqs = len(seqs)
n_outs = len(outs_info)
# initialize the inplace map, sequences map and
# outputs map
''' Details:
The scan op identifies different properties attached
to input tensors by their order in the input list.
These maps ( inplace, sequence_taps, output_taps,
store_steps, return_steps) go from the index of an input to
its properties. Note that inputs are always first, followed
by outputs. Since we always know the number of inputs we
index the outputs from 0 ( so sometimes you will need to
do something like outputs_taps[i-n_ins] )
'''
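A tiny illustration of this indexing convention (the values are made up):

```python
# Made-up values illustrating the convention: with n_ins inputs, the
# property maps for outputs are indexed from 0, so the entry for a
# combined index i is found at i - n_ins.
n_ins = 3
outputs_taps = {0: [-1], 1: [-2, -1]}     # keyed by output index

def taps_for(i):
    """Taps for combined index i (inputs come first, then outputs)."""
    return outputs_taps.get(i - n_ins)
```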
inplace_map = {}
@@ -498,13 +584,13 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# which would indicate that the sequence was provided but
# not used by the internal function; Only if the user has
# not provided anything add the default [0]
# A possible reason to provide a sequence and not use it is
# if you want to compute the output
# inplace of this input; it is a very unlikely behaviour but
# we do want to cover it for completeness
if not seqs[i].has_key('taps'):
seqs[i]['taps'] = [0]
# Now that our input is well behaved, collect the taps in the
# sequences_taps map that we will use later in the body of scan
# since inputs will be just tensors there
if seqs[i].get('taps',None):
@@ -514,14 +600,14 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# in one and in the same pass create a init_outs_taps dictionary and a inplace map
for i in xrange(n_outs):
if outs_info[i]:
# If output is a dictionary, collect the number of steps the
# user would like scan to return
if type(outs_info[i]) == dict:
if outs_info[i].get('return_steps', None):
return_steps[i] = outs_info[i]['return_steps']
# If you provide the number of steps to store internally,
# (not advocated in the user documentation), then also
# make sure you are returning only those number of steps
if outs_info[i].get('store_steps', None):
store_steps += [outs_info[i].get('store_steps',None)]
return_steps[i] = outs_info[i].get('store_steps',None)
@@ -540,11 +626,11 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
(outs_info[i].get('taps',None)):
raise ValueError('If you are using slices of an output you need to '\
'provide an initial state for it', outs_info[i])
# if there is an initial state but no tap, we will add the default value
# for taps, namely [-1] (previous value); note that this will happen
# even though you have provided for taps the value None, which is a bit
# strange (why would one provide an initial state but tell scan not to
# use it?), just that in that case we will throw in a warning message
# pointing out this inconsistency
elif outs_info[i].get('initial',None) and \
( not outs_info[i].get('taps',None)):
@@ -556,18 +642,18 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
'provide the initial state')
outs_info[i]['taps'] = [-1]
else:
# if the output is a None then replace it with an empty dictionary for
# easing up dealing with this case later on (we can directly call .has_key
# and things like this
outs_info[i] = dict()
store_steps += [0]
if outs_info[i].get('taps', None):
# Create a separate outputs_taps dictionary with all the outputs taps; This
# is how the Scan Op expects this information, separated from the variables
outputs_taps[i] = outs_info[i]['taps']
if outs_info[i].get('inplace', None):
# The same is true for the inplace info; it has to go into a separate
# dictionary based on index; Note that the input we're replacing should also
# come as an index, therefore we have to look for it at this point
found = None
@@ -575,7 +661,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
if seqs[k].get('input', None) == outs_info[i].get('inplace',None):
found = k
if found != None:
# NOTE : inplace_map is identical to destroy_map, i.e. it tells what
# output is computed inplace of what input !!
inplace_map[i] = found
else:
@@ -602,12 +688,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# create one slice of the input
'''
Later on, if we decide not to use scan because we are going
for just one step, it makes things easier if we compute the
correct outputs here. This way we can use the output of the
lambda expression directly to replace the output of scan.
If not we need to use copies, that will be replaced at each
frame by the corresponding slice
'''
if n_fixed_steps not in [1,-1]:
nw_slice = seq['input'][0].type()
@@ -625,10 +711,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
else:
nw_slice.name = seq['input'].name + '[t%d]'%seq['taps'][k]
args.append(nw_slice)
# Specify to whom this slice belongs
slice_to_seqs.append(i)
# Any slice is not a shared variable, even though the sequence
# from where we pick the slices is shared, therefore we should
# increase the number of notshared inputs to the dummy function
# by the number of slices
dummy_notshared_ins += len(seq['taps'])
@@ -636,7 +722,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
for i,init_out in enumerate(outs_info):
# Note that our convention dictates that if an output uses
# just the previous time step, as an initial state we will only provide
# a tensor of the same dimension as one time step; This makes code
# much cleaner for those who do not use taps. Otherwise they would
# always have to shape_pad_left the initial state, which is ugly
if init_out.get('taps', None) == [-1]:
@@ -647,9 +733,9 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# Added name to slices for debugging and pretty printing
if init_out['initial'].name:
args[-1].name = init_out['initial'].name+'[t-1]'
# we need to specify in slice_seqs to which output this
# slice belongs; Because we might get confused afterwards
# if a number is an index of a sequence or an output, and
# because we do not want to create yet another list, we will
# add the number of sequences + the current output. This makes
# decoding easy and spares us from writing a lot of lines
@@ -682,11 +768,11 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# add as many slices as there are taps
dummy_notshared_init_outs += len(init_out['taps'])
#NOTE: there is another case, in which we do not want to provide any previous
# value of the output to the inner function; in this case we do not have to do
# anything ..
# remove shared variables from the non sequences list
# such that we can compile the function ( the user has the option to add them when
# writing scan, because in some situations this might make the code more readable)
notshared_other_args = []
for non_seq in non_seqs:
@@ -707,7 +793,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
outputs_updates = fn(*args)
# The code that follows tries to be as flexible as possible, allowing the
# user to return the output and updates in any order, and giving the updates
# however he wants (as a dictionary or a list of pairs ..)
# Is there a way to compress all this by writing it in a more python/functional way?
@@ -747,7 +833,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
outputs = outputs_updates
updates = {}
# in case you return a tuple .. convert it to a list (there are certain
# operations that are not permitted on tuples, like element assignment)
outputs = list(outputs)
@@ -765,12 +851,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# so we can do stuff as unoptimal as we wish ]
if n_fixed_steps in [-1,1]:
''' We do have a special case here, namely it might so happen that
whatever we have in dummy_args is not sufficient to compile the
function (i.e. missing inputs). Furthermore we might not even need
to compile the function here for this special case. But due to the
way I wrote the code it is easier to have a compiled function here
that I can ignore later. Plus it is easier this way to take care
of shared variables with non-default updates. Therefore only for
this case I need to use gof.graph.inputs to look for the real inputs
so that I can compile the function. RP '''
dummy_f = function(filter(lambda x: isinstance(x, gof.Variable) and \
@@ -802,12 +888,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# other updates :
for i in xrange(n_outs):
outs_info += [ dict() ]
# we also need to re-initialize the store_steps list to match the
# number of outputs
store_steps = [ 0 for i in xrange(n_outs)]
else:
# Otherwise there is a bit of confusion, since Scan works on the index of
# a sequence /output. There are maybe corner cases that could be added here
# or default behaviour ( like always add the extra outputs at the end !?)
# But I did not bother implementing this, I leave it to the user to clearly
......@@ -832,7 +918,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
fromIdx = dummy_notshared_ins + dummy_notshared_init_outs
copy_map = {}
for input in dummy_f.maker.expanded_inputs[fromIdx:] :
# If input is a shared variable that gets updated, then
# this shared variable will be an output of our inner function
if isinstance(input.variable, SharedVariable) and input.update:
# Create a copy of it
@@ -857,8 +943,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# inner_fn_shared_ins_idx stores where we stop having shared variables with updates
inner_fn_shared_ins_idx = len(inner_fn_inputs) - inner_fn_notshared_ins_idx
# Now that we took out the shared variables that have an update rule
# we need to take care of all the other shared variables
for input in dummy_f.maker.expanded_inputs[fromIdx:] :
# make sure that we do not add the same shared variable twice
if isinstance(input.variable, SharedVariable) and not input.update:
@@ -871,14 +957,14 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
givens[input.variable] = inner_fn_inputs[-1]
copy_map[inner_fn_inputs[-1]] = input.variable
elif not isinstance(input.variable, SharedVariable):
# also add the normal tensors that are non sequences at the
# end of the inputs, intertwined with the shared variables
inner_fn_inputs.append(input.variable)
inner_fn_inputs.append(input.variable)
# If we haven't provided a number of steps nor did we provide a sequence
# scan will not know how long to iterate
if (n_steps is None or n_steps == numpy.inf or (isinstance(n_steps, float) and numpy.isnan(n_steps))) and n_seqs == 0:
raise ValueError('Scan does not know for how many steps to iterate. '
'You need to provide the number of steps through the '
' ``n_steps`` argument if you do not iterate over any sequence')
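The guard above refuses to build the op when the loop length cannot be inferred: `numpy.nan` never compares equal to itself, which is why the NaN case needs `numpy.isnan` rather than `==`. A minimal pure-Python sketch of the same rule (the helper name is hypothetical, not Theano's API):

```python
import numpy

def infer_n_steps(n_steps, n_seqs):
    """Mimic scan's guard: without an explicit, finite ``n_steps`` and
    without any sequence to take the length from, the loop length is
    undefined and we must refuse to build the op."""
    undefined = (n_steps is None
                 or n_steps == numpy.inf
                 or (isinstance(n_steps, float) and numpy.isnan(n_steps)))
    if undefined and n_seqs == 0:
        raise ValueError('Scan does not know for how many steps to iterate.')
    return n_steps

infer_n_steps(10, 0)    # fine: explicit step count
infer_n_steps(None, 2)  # fine: the length is taken from the sequences
```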
......@@ -925,19 +1011,19 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
if not type(values) in (tuple, list):
values = [values]
# take out the updates of shared variable and build the dictionary
# that tells what to update and with what value
for val in update_map.keys():
update_map[val] = values [ update_map[val] ]
# Now we need to check the values returned
# if it is just one, strip the list around it
if n_outs == 1:
# if we need to return just one step or several steps
# note that when we return one step we have two cases: in
# the first one store_steps is set to 1, in which case we don't
# need to take a slice of the output (it is already of the right
# dimension), and case 2, when we store more than one step
# and we actually need to take a slice
if return_steps.has_key(0):
if return_steps[0] > 1:
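The `store_steps` bookkeeping above decides how much history each output keeps: 0 means the full history, a positive value means a bounded rolling buffer. A toy model of that storage policy (plain Python, not Theano code):

```python
def run_with_store_steps(n_steps, store_steps, step_fn, init):
    """Toy model of scan's storage policy: keep the full history when
    store_steps == 0, otherwise keep only the last ``store_steps``
    values in a rolling (circular) buffer."""
    if store_steps == 0:
        buf = [init]
        for _ in range(n_steps):
            buf.append(step_fn(buf[-1]))
        return buf[1:]          # full history, one entry per step
    buf = [init] * store_steps
    last = init
    for t in range(n_steps):
        last = step_fn(last)
        buf[t % store_steps] = last   # overwrite the oldest slot
    return buf

run_with_store_steps(5, 0, lambda x: x + 1, 0)  # full history: [1, 2, 3, 4, 5]
run_with_store_steps(5, 1, lambda x: x + 1, 0)  # only the last step: [5]
```

This is also the intuition behind the `ScanSpaceOptimizer` further down: when clients only ever read the last slice of an output, storing one step is enough.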
......@@ -969,11 +1055,11 @@ class Scan(Op):
#
def __init__(self,(inputs, outputs, givens, slice_to_seqs),n_seqs, n_outs,
inplace_map={}, seqs_taps={}, outs_taps={},
n_steps = gof.Constant(gof.generic, 'unknown', '?_steps'),
truncate_gradient = -1, n_outs_not_shared =0,
inner_fn_start_shared = 0, inner_fn_end_shared = 0,
go_backwards = False, store_steps = {},
return_steps={}, mode = None, inplace=False, name = None):
'''
:param (inputs,outputs, givens,slice_to_seqs):
......@@ -1014,7 +1100,7 @@ class Scan(Op):
if inplace:
for i in inplace_map.keys():
# the n_steps is always the first argument of scan's perform,
# so we need to shift everything by 1
self.destroy_map.update({i: [inplace_map[i]+1] } )
# make all inplace inputs mutable for the inner function for extra efficiency
for idx in xrange(len(inputs)):
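The `+1` shift in `destroy_map` above exists because `n_steps` is always the first runtime argument of scan's `perform`, so input `i` of the op corresponds to runtime argument `i + 1`. As a standalone illustration (plain Python, hypothetical helper name):

```python
def build_destroy_map(inplace_map):
    """Shift each destroyed-input index by 1: at runtime the first
    positional argument of Scan.perform is n_steps, so the op's
    input i maps to runtime argument i + 1."""
    return {i: [inplace_map[i] + 1] for i in inplace_map}

build_destroy_map({0: 0, 2: 3})   # output 0 destroys arg 1, output 2 destroys arg 4
```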
......@@ -1041,10 +1127,10 @@ class Scan(Op):
self.inner_fn_start_shared = inner_fn_start_shared
self.inner_fn_end_shared = inner_fn_end_shared
self.outputs = outputs
self.n_steps = n_steps # It will be computed at runtime
# This is here just for an optimization to be able to pick up if
# scan is really needed in the graph; if the number of steps
# scan does is a constant of 1, -1 or 0 then we can remove scan
# from the graph
self.mode = mode
self.truncate_gradient = truncate_gradient
......@@ -1346,8 +1432,8 @@ class Scan(Op):
#update outputs
for j in xrange(n_outs):
if self.store_steps[j] <1:
# if you have provided no size for the missing output you might
# find yourself here with an incorrect array .. if that happens
# reallocate memory for the needed array
try :
if hasattr(something[j],'dtype') and (y[j].dtype != \
......@@ -1393,13 +1479,13 @@ class Scan(Op):
# make sure they are given as a list
if not( type(scan_outputs) in (list,tuple)):
scan_outputs = [scan_outputs]
# get a list of clean inputs ( against which one can compute
# get a list of clean inputs ( against which one can compute
# gradients ) [ everything except shared variables with updates ]
clean_inputs = self.inputs[:self.inner_fn_start_shared] + \
self.inputs[self.inner_fn_start_shared + \
self.inner_fn_end_shared:]
clean_inputs = [ self.copy_map.get(x,x) for x in clean_inputs]
s_inputs = [self.copy_map.get(x,x) for x in self.inputs ]
# function that computes the gradient (we sum over the gradients
......@@ -1453,11 +1539,11 @@ class Scan(Op):
if inner_gfn_outs[i] == None:
inner_gfn_outs[i] = tensor.zeros_like(clean_inputs[i])
for i in xrange(self.n_outs_not_shared):
# Safety check
if g_outs[i] == None:
try:
# this try is for catching non ndarray inputs (random states)
# it is more of a safety check ( all random states should be
# after n_outs_not_shared ...
g_outs[i] = tensor.zeros_like(scan_outputs[i])
except:
......@@ -1473,9 +1559,9 @@ class Scan(Op):
raise ValueError('Can not compute gradients if one does not '
'store all intermediate results (remove store_steps '
'from the dictionaries describing your outputs)')
g_scan = ScanGrad((inner_gfn_ins, inner_gfn_outs),
self.n_seqs, self.n_outs, self.n_outs_not_shared,
self.go_backwards, self.seqs_taps, self.outs_taps,
truncate_gradient)
g_scan_outs = g_scan(g_args)
# We need to add several None's for shared vars with updates
......@@ -1487,9 +1573,9 @@ class Scan(Op):
class ScanGrad(Op):
"""Gradient Op for Scan"""
def __init__(self,(g_ins, g_outs) , n_seqs, n_outs,
n_outs_not_shared,
go_backwards = False, seqs_taps = {}, outs_taps= {},
truncate_gradient = -1, mode = None, name = None):
"""
:param mode: see scan fct
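`ScanGrad` carries the same `truncate_gradient` parameter as the forward op: `-1` backpropagates through every step, any positive value cuts the chain after that many steps (truncated BPTT). A toy rule for the resulting backward step count (hypothetical helper, not ScanGrad's actual code):

```python
def backward_steps(n_steps, truncate_gradient):
    """How many steps the gradient flows through: all of them when
    truncate_gradient == -1, otherwise at most truncate_gradient."""
    if truncate_gradient == -1:
        return n_steps
    return min(n_steps, truncate_gradient)

backward_steps(100, -1)  # full backprop through time: 100
backward_steps(100, 20)  # truncated after 20 steps
```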
......@@ -1643,7 +1729,7 @@ class ScanGrad(Op):
for k in outInfo[:self.n_outs_not_shared]]
g_non_seqs = [numpy.zeros_like(k) for k in non_seqs]
# get gradient on the outputs
g_outs = [arg.copy() for arg in args[1:self.n_outs_not_shared+1]]
# get the output of the scan operation
......@@ -1776,8 +1862,8 @@ class ScanSpaceOptimizer(Optimizer):
# look at all its clients
for cl,_dx in out.clients:
if type(cl) == str:
# if the node is actually an output, then
# we need to store the entire thing
req_steps = None
break
else:
......@@ -1788,12 +1874,12 @@ class ScanSpaceOptimizer(Optimizer):
req_steps = None
break
else:
# if it is a tensor, and the first
# dimension is just -1
if cl.op.idx_list[0] == -1 and req_steps != None:
req_steps = numpy.max([1, req_steps])
else:
# or a constant that evaluates to
# -1
try:
idx = opt.get_constant_value(\
......@@ -1810,23 +1896,23 @@ class ScanSpaceOptimizer(Optimizer):
else:
store_steps[i] = op.store_steps[i]
if numpy.any(store_steps != op.store_steps):
new_scan = Scan((op.inputs, op.outputs, op.givens,
op.slice_to_seqs),op.n_seqs, op.n_outs,
op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps,
op.truncate_gradient, op.n_outs_not_shared,
op.inner_fn_start_shared, op.inner_fn_end_shared,
op.go_backwards, store_steps, op.return_steps, op.mode,
op.inplace, name = op.fn.name).make_node(*node.inputs)
# we now need to replace the outputs of scan
for i,out in enumerate(node.outputs):
# if we are dealing with an output for which
# we changed the number of stored steps we
# also need to get rid off the subtensor
if op.store_steps[i] == 0 and store_steps[i] == 1:
# get the output of the subtensor variables
outSubTens = [ x[0].outputs[0] for x in out.clients ]
new_old = [(x,new_scan.outputs[i]) for x in outSubTens]
env.replace_all_validate(new_old,reason =
'scan_space_optimizer')
else:
env.replace_all_validate([(out,
......@@ -1843,7 +1929,7 @@ def scan_make_inplace(node):
return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs ) , op.n_seqs,
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps,
op.truncate_gradient, op.n_outs_not_shared, op.inner_fn_start_shared,
op.inner_fn_end_shared, op.go_backwards, op.store_steps, op.return_steps,
op.mode, inplace=True, name = op.fn.name).make_node(*node.inputs).outputs
return False