testgroup / pytensor · Commits

Commit 96676ed5, authored Oct 14, 2010 by Razvan Pascanu

[scan][doc][coding-style] re-arranged the documentation of scan parameters

Parent: b15fadcc

Showing 1 changed file with 307 additions and 221 deletions.

theano/scan.py  +307 −221  @ 96676ed5
...

@@ -268,164 +268,250 @@ def foldr( fn
    # Yes, actually it will be exactly 2 ( if there are no other constraints)


def scan( fn,
          sequences         = None,
          outputs_info      = None,
          non_sequences     = None,
          n_steps           = None,
          truncate_gradient = -1,
          go_backwards      = False,
          mode              = None,
          name              = None ):
    """
    This function constructs and applies a Scan op to the provided
    arguments.

    :param fn:
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. ``fn`` should construct variables describing the
        output of one iteration step. It should expect as input Theano
        variables representing all the time slices of the input sequences
        and outputs, and all other arguments given to scan as
        ``non_sequences``. The order in which scan passes these variables
        to ``fn`` is the following:

        * all time slices of the first sequence
        * all time slices of the second sequence
        * ...
        * all time slices of the last sequence
        * all time slices of the first output
        * all time slices of the second output
        * ...
        * all time slices of the last output
        * all other arguments (the list given as ``non_sequences`` to scan)

        The order of the sequences is the same as the one in the list
        ``sequences`` given to scan. The order of the outputs is the same
        as the order of ``outputs_info``. For any sequence or output the
        order of the time slices is the same as the order of the time
        taps provided. For example, if one writes the following:

        .. code-block:: python

            scan(fn, sequences = [ dict( Sequence1, taps = [-3,2,-1])
                                 , Sequence2
                                 , dict( Sequence3, taps = 3) ]
                   , outputs_info = [ dict( Output1, taps = [-3,-5])
                                    , dict( Output2, taps = None)
                                    , Output3 ]
                   , non_sequences = [ Argument1, Argument2 ])

        ``fn`` should expect the following arguments in this given order:

        #. ``Sequence1[t-3]``
        #. ``Sequence1[t+2]``
        #. ``Sequence1[t-1]``
        #. ``Sequence2[t]``
        #. ``Sequence3[t+3]``
        #. ``Output1[t-3]``
        #. ``Output1[t-5]``
        #. ``Output3[t-1]``
        #. ``Argument1``
        #. ``Argument2``

        The list of ``non_sequences`` can also contain shared variables
        used in the function, though ``scan`` is able to figure those
        out on its own, so they can be skipped. For the clarity of the
        code we recommend, though, to provide them to scan.

        The function is expected to return two things. One is a list of
        outputs ordered in the same order as ``outputs_info``, with the
        difference that there should be only one output variable per
        output initial state (even if no tap value is used). Secondly,
        ``fn`` should return an update dictionary (that tells how to
        update any shared variable after each iteration step). The
        dictionary can optionally be given as a list of tuples. There is
        no constraint on the order of these two lists; ``fn`` can return
        either ``(outputs_list, update_dictionary)`` or
        ``(update_dictionary, outputs_list)`` or just one of the two (in
        case the other is empty).

    :param sequences:
        ``sequences`` is the list of Theano variables or dictionaries
        describing the sequences ``scan`` has to iterate over. If a
        sequence is given wrapped in a dictionary, a set of optional
        information can be provided about the sequence. The dictionary
        should have the following keys:

        * ``input`` (*mandatory*) -- Theano variable representing the
          sequence.

        * ``taps`` -- Temporal taps of the sequence required by ``fn``.
          They are provided as a list of integers, where a value ``k``
          implies that at iteration step ``t`` scan will pass to ``fn``
          the slice ``t+k``. Default value is ``[0]``.

        Any Theano variable in the list ``sequences`` is automatically
        wrapped into a dictionary where ``taps`` is set to ``[0]``.

    :param outputs_info:
        ``outputs_info`` is the list of Theano variables or dictionaries
        describing the initial state of the outputs computed recurrently.
        When these initial states are given as dictionaries, optional
        information can be provided about the output corresponding to
        those initial states. The dictionary should have the following
        keys:

        * ``initial`` -- Theano variable that represents the initial
          state of a given output. In case the output is not computed
          recursively (think of a map) and does not require an initial
          state, this field can be skipped. Given that only the previous
          time step of the output is used by ``fn``, the initial state
          should have the same shape as the output. If multiple time
          taps are used, the initial state should have one extra
          dimension that covers all the possible taps. For example, if
          we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0
          ``fn`` will require (by an abuse of notation) ``output[-5]``,
          ``output[-2]`` and ``output[-1]``. These will be given by the
          initial state, which in this case should have the shape
          ``(5,) + output.shape``. If this variable containing the
          initial state is called ``init_y``, then ``init_y[0]``
          *corresponds to* ``output[-5]``, ``init_y[1]`` *corresponds to*
          ``output[-4]``, ``init_y[2]`` corresponds to ``output[-3]``,
          ``init_y[3]`` corresponds to ``output[-2]``, and ``init_y[4]``
          corresponds to ``output[-1]``. While this order might seem
          strange, it comes naturally from splitting an array at a given
          point. Assume that we have an array ``x``, and we choose ``k``
          to be time step ``0``. Then our initial state would be
          ``x[:k]``, while the output will be ``x[k:]``. Looking at this
          split, elements in ``x[:k]`` are ordered exactly like those in
          ``init_y``.

        * ``taps`` -- Temporal taps of the output that will be passed to
          ``fn``. They are provided as a list of *negative* integers,
          where a value ``k`` implies that at iteration step ``t`` scan
          will pass to ``fn`` the slice ``t+k``.

        * ``inplace`` -- One of the Theano variables provided as
          ``sequences``. ``scan`` will try to compute this output *in
          place* of the provided input *iff* it respects the following
          constraints:

            * There is no other output that is denied to be computed in
              place for whatever reason.
            * ``fn`` is not using past taps of the input sequence that
              will get overwritten by the output.

        * ``return_steps`` -- Integer representing the number of steps
          to return for the current output. For example, if ``k`` is
          provided, ``scan`` will return ``output[-k:]``. This is meant
          as a hint; based on ``k`` and the past taps of the outputs
          used, scan can be smart about the amount of memory it requires
          to store intermediate results. If not given, or ``0``, ``scan``
          will return all computed steps.

        * ``store_steps`` -- Integer representing the number of
          intermediate steps ``scan`` should store for a given output.
          Use this key only if you really know what you are doing. In
          general it is recommended to let scan decide for you the amount
          of memory it should use.

        ``scan`` will follow this logic if partial information is given:

        * If an output is not wrapped in a dictionary, ``scan`` will wrap
          it in one, assuming that you use only the last step of the
          output (i.e. it makes your tap value list equal to ``[-1]``)
          and that it is not computed inplace.
        * If you wrap an output in a dictionary and you do not provide
          any taps but you provide an initial state, it will assume that
          you are using only a tap value of -1.
        * If you wrap an output in a dictionary but you do not provide
          any initial state, it assumes that you are not using any form
          of taps.
        * If you provide ``None`` instead of a variable or a dictionary,
          ``scan`` assumes that you will not use any taps for this output
          (like, for example, in the case of a map).

        If ``outputs_info`` is an empty list or ``None``, ``scan`` assumes
        that no tap is used for any of the outputs. If information is
        provided for just a subset of the outputs, an exception is raised
        (because there is no convention on how scan should map the
        provided information to the outputs of ``fn``).

    :param non_sequences:
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each step. One can opt to exclude shared variables
        used in ``fn`` from this list.

    :param n_steps:
        ``n_steps`` is the number of steps to iterate, given as an int
        or Theano scalar. If any of the input sequences do not have
        enough elements, scan will produce a warning and run only for
        the maximal number of steps it can. If the *value is 0*, the
        outputs will have *0 rows*. If the value is negative, ``scan``
        runs backwards in time. If the ``go_backwards`` flag is already
        set and ``n_steps`` is also negative, ``scan`` will run forward
        in time. If ``n_steps`` is not provided, or evaluates to ``None``,
        ``inf`` or ``NaN``, ``scan`` will figure out the number of steps
        it should run given its input sequences.

    :param truncate_gradient:
        ``truncate_gradient`` is the number of steps to use in truncated
        BPTT. If you compute gradients through a scan op, they are
        computed using backpropagation through time. By providing a value
        different from -1, you choose to use truncated BPTT instead of
        classical BPTT, where you go back in time only
        ``truncate_gradient`` number of steps.

    :param go_backwards:
        ``go_backwards`` is a flag indicating if ``scan`` should go
        backwards through the sequences. If you think of each sequence
        as indexed by time, making this flag True would mean that
        ``scan`` goes back in time, namely that for any sequence it
        starts from the end and goes towards 0.

    :param name:
        When profiling ``scan`` it is crucial to provide a name for any
        instance of ``scan``. The profiler will produce an overall
        profile of your code as well as profiles for doing one iteration
        step of each instance of ``scan``. The ``name`` of the instance
        is how you differentiate between all these profiles.

    :param mode:
        It is recommended to leave this argument to ``None``, especially
        when profiling ``scan`` (otherwise the results are not going to
        be accurate). If you prefer the computations of one step of
        ``scan`` to be done differently than the entire function, set
        this parameter (see ``theano.function`` for details about
        possible values and their meaning).

    :rtype: tuple
    :return: tuple of the form ``(outputs, updates)``; ``outputs`` is
             either a Theano variable or a list of Theano variables
             representing the outputs of ``scan`` (in the same order as
             in ``outputs_info``). ``updates`` is a dictionary specifying
             the update rules for all shared variables used in the scan
             operation; this dictionary should be passed to
             ``theano.function`` when you compile your function.
    """
    # General observation : this code is executed only once, at creation
    # of the computational graph, so we don't yet need to be smart about
    # anything ( to speed things up)

    # check if inputs are just single variables instead of lists
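To make the argument-ordering and return conventions documented above concrete, here is a minimal usage sketch. It is illustrative only: it assumes the 2010-era API described in this docstring (a module-level ``scan`` returning ``(outputs, updates)``), and every name in it (``step``, ``y0``, the import path, and so on) is hypothetical rather than taken from the commit.

    import theano
    import theano.tensor as T
    from theano.scan import scan   # assumption about the import path in this revision

    x  = T.vector('x')    # the sequence scan iterates over (default tap [0])
    y0 = T.vector('y0')   # initial state covering taps [-2, -1]: one extra leading dim of size 2
    a  = T.scalar('a')    # a non-sequence argument passed unchanged at every step

    # Per the docstring, fn receives: the sequence slice x[t], then the output
    # slices y[t-2] and y[t-1] (in the order the taps were listed), then the
    # non-sequences; it returns one variable per output, plus optional updates.
    def step(x_t, y_tm2, y_tm1, a):
        return a * y_tm1 + y_tm2 + x_t

    outputs, updates = scan(step,
                            sequences=[x],
                            outputs_info=[dict(initial=y0, taps=[-2, -1])],
                            non_sequences=[a])

    # As the :return: section says, the updates dictionary should be handed to
    # theano.function so that shared-variable update rules take effect.
    f = theano.function([x, y0, a], outputs, updates=updates)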
...

@@ -449,7 +535,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    if type(n_steps) in (float, int):
        n_fixed_steps = int(n_steps)
    else:
        # also check if this value happens to be a constant,
        # then we could do the same
...

@@ -460,16 +546,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    # initialize the inplace map, sequences map and
    # outputs map
    ''' Details:
        The scan op identifies different properties attached
        to input tensors by their order in the input list.
        These maps ( inplace, sequence_taps, output_taps,
        store_steps, return_steps) go from the index of an input to
        its properties. Note that inputs are always first, followed
        by outputs. Since we always know the number of inputs we
        index the outputs from 0 ( so sometimes you will need to
        do something like outputs_taps[i-n_ins]
    '''
    inplace_map = {}
...

@@ -498,13 +584,13 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        # which would indicate that the sequence was provided but
        # not used by the internal function; Only if the user has
        # not provided anything add the defaul [0]
        # Possible reason to provide a squence and not use it is
        # if you want to compute the output
        # inplace of this input; it is a very unlikely behaviour but
        # we do want to cover it for completeness
        if not seqs[i].has_key('taps'):
            seqs[i]['taps'] = [0]
        # Now that our input is well behaved, collect the taps in the
        # sequences_taps map that we will use later in the body of scan
        # since inputs will be just tensors there
        if seqs[i].get('taps', None):
...

@@ -514,14 +600,14 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # in one and in the same pass create a init_outs_taps dictionary and a inplace map
    for i in xrange(n_outs):
        if outs_info[i]:
            # If output is a dictionary, collect the number of steps the
            # user would like scan to return
            if type(outs_info[i]) == dict:
                if outs_info[i].get('return_steps', None):
                    return_steps[i] = outs_info[i]['return_steps']
                # If you provide the number of steps to store internally,
                # (not advocated in the user documentation), then also
                # make sure you are returning only those number of steps
                if outs_info[i].get('store_steps', None):
                    store_steps += [outs_info[i].get('store_steps', None)]
                    return_steps[i] = outs_info[i].get('store_steps', None)
...

@@ -540,11 +626,11 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
                    (outs_info[i].get('taps', None)):
                raise ValueError('If you are using slices of an output you need to '\
                                 'provide a initial state for it', outs_info[i])
            # if there is an intial state but no tap, we will add the default value
            # for taps, namely [-1] ( previous value); not that this will happen
            # even though you have provided for taps the value None, which is a bit
            # strange (why would one provide an initial state but tell scan not to
            # use it ? ), just that in that case we will throw in a warning message
            # pointing out this inconsistency
            elif outs_info[i].get('initial', None) and \
                    (not outs_info[i].get('taps', None)):
...

@@ -556,18 +642,18 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
                                 'provide the initial state')
                outs_info[i]['taps'] = [-1]
        else:
            # if the output is a None then replace it with an empty dictionary for
            # easing up dealing with this case later one ( we can directly call .has_key
            # and things like this
            outs_info[i] = dict()
            store_steps += [0]
        if outs_info[i].get('taps', None):
            # Create a separate outputs_taps dictionary with all the outputs taps; This
            # is how the Scan Op expects this information, separeted from the variables
            outputs_taps[i] = outs_info[i]['taps']
        if outs_info[i].get('inplace', None):
            # The same is true for the inplace info; it has to go into a separate
            # dictionary based on index; Note that the input we're replacing should also
            # come as an index, therefore we have to look for it at this point
            found = None
...

@@ -575,7 +661,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
                if seqs[k].get('input', None) == outs_info[i].get('inplace', None):
                    found = k
            if found != None:
                # NOTE : inplace_map is identical to destroy_map, i.e. it tells what
                # output is computed inplace of what input !!
                inplace_map[i] = found
            else:
...

@@ -602,12 +688,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        # create one slice of the input
        '''
        Later on, if we decide not to use scan because we are going
        for just one step, it makes things easier if we compute the
        correct outputs here. This way we can use the output of the
        lambda expression directly to replace the output of scan.
        If not we need to use copies, that will be replaced at each
        frame by the corresponding slice
        '''
        if n_fixed_steps not in [1, -1]:
            nw_slice = seq['input'][0].type()
...

@@ -625,10 +711,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
            else:
                nw_slice.name = seq['input'].name + '[t%d]' % seq['taps'][k]
            args.append(nw_slice)
            # Specify to whom this slice belongs
            slice_to_seqs.append(i)
        # Any slice is not a shared variable, even though the sequence
        # from where we pick the slices is shared, therefore we should
        # increase the number of notshared inputs to the dummy function
        # by the number of slices
        dummy_notshared_ins += len(seq['taps'])
...

@@ -636,7 +722,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as a initial state we will only provide
        # a tensor of the same dimension as one time step; This makes code
        # much cleaner for those who do not use taps. Otherwise they would
        # always had to shape_pad_left the initial state .. which is ugly
        if init_out.get('taps', None) == [-1]:
...

@@ -647,9 +733,9 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
            # Added name to slices for debugging and pretty printing
            if init_out['initial'].name:
                args[-1].name = init_out['initial'].name + '[t-1]'
        # we need to specify in slice_seqs to which output this
        # slice belongs; Because we might get confused afterwards
        # if a number is an index of a sequence or an output, and
        # because we do not want to create yet another list, we will
        # add the number of sequences + the current output. This makes
        # decoding easy and spares us from writing a lot of lines
...

@@ -682,11 +768,11 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
            # add as many slices as there are taps
            dummy_notshared_init_outs += len(init_out['taps'])
        #NOTE: there is another case, in which we do not want to provide any previous
        # value of the output to the inner case; in this case we do not have to do
        # anything ..

    # remove shared variables from the non sequences list
    # such that we can compile the function ( the user has the option to add them when
    # writing scan, because in some situations this might make the code more readable)
    notshared_other_args = []
    for non_seq in non_seqs:
...

@@ -707,7 +793,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # when we apply the lambda expression we get a mixture of update rules
    # and outputs that needs to be separated
    outputs_updates = fn(*args)
    # The code that follows tries to be as flexible as possible allowing the
    # user to return the output and updates in any order, and giving the updates
    # however he wants ( as a dictionary or a list o pairs ..)
    # Is there a way to compress all this by writing it in a more python/functional way?
...

@@ -747,7 +833,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        outputs = outputs_updates
        updates = {}
    # in case you return a tuple .. convert it to a list (there are certain
    # operation that are not permited on tuples, like element assignment)
    outputs = list(outputs)
...

@@ -765,12 +851,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # so we can do stuff as unoptimal as we wish ]
    if n_fixed_steps in [-1, 1]:
        ''' We do have a special case here, namely is so might happen that
        whatever we have in dummy_args is not sufficient to compile the
        function( i.e. missing inputs). Furthermore we might not even need
        to compile the function here for this special case. But due to the
        way I wrote the code is easier to have a compiled function here
        that I can ignore later. Plus it is easier this way to take care
        of shared variables with non-default updates. Therefore only for
        this case I need to use gof.graph.inputs to look for the real inputs
        so that I can compile the function. RP '''
        dummy_f = function(filter(lambda x: isinstance(x, gof.Variable) and \
...

@@ -802,12 +888,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        # other updates :
        for i in xrange(n_outs):
            outs_info += [dict()]
        # we also need to re-initialize the store_steps list to match the
        # number of outputs
        store_steps = [0 for i in xrange(n_outs)]
    else:
        # Otherwise there is a bit of confusion, since Scan works on the index of
        # a sequence /output. There are maybe corner cases that could be added here
        # or defult behaviour ( like always add the extra outputs at the end !?)
        # But I did not bother implementing this, I leave it to the user to clearly
...

@@ -832,7 +918,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    fromIdx = dummy_notshared_ins + dummy_notshared_init_outs
    copy_map = {}
    for input in dummy_f.maker.expanded_inputs[fromIdx:]:
        # If input is a shared variable that gets updated, then
        # this shared variable will be an output of our inner function
        if isinstance(input.variable, SharedVariable) and input.update:
            # Create a copy of it
...

@@ -857,8 +943,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # inner_fn_shared_ins_idx stores where we stop having shared variables with updates
    inner_fn_shared_ins_idx = len(inner_fn_inputs) - inner_fn_notshared_ins_idx

    # Now that we took out the shared variables that have an update rule
    # we need to take care of all the other shared variables
    for input in dummy_f.maker.expanded_inputs[fromIdx:]:
        # make sure that we do not add the same shared variable twice
        if isinstance(input.variable, SharedVariable) and not input.update:
...

@@ -871,14 +957,14 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
            givens[input.variable] = inner_fn_inputs[-1]
            copy_map[inner_fn_inputs[-1]] = input.variable
        elif not isinstance(input.variable, SharedVariable):
            # also add the normal tensor that are non sequences at the
            # end of the inputs intertwingled with the shared variables
            inner_fn_inputs.append(input.variable)

    # If we haven't provided a number of steps nor did we provide a sequence
    # scan will not know how long to iterate
    if (n_steps == None or n_steps == numpy.inf or n_steps == numpy.nan) and n_seqs == 0:
        raise ValueError('Scan does not know for how many steps to iterate. '
                         'You need to provide the number of steps through the '
                         ' ``n_steps`` argument if you do not iterate over any sequence')
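The check just above enforces what the docstring says about ``n_steps``: when ``scan`` is given no sequences, the number of iterations has to come from ``n_steps``. Below is a hypothetical sketch of that usage, under the same API assumptions as the example after the docstring (names and import path are illustrative only).

    import theano
    import theano.tensor as T
    from theano.scan import scan   # assumption, as in the earlier sketch

    k  = T.iscalar('k')            # how many iterations to run
    a0 = T.scalar('a0')            # initial state, default tap [-1]

    # No sequences at all: without n_steps this is exactly the ValueError above.
    powers, updates = scan(lambda a_tm1: 2 * a_tm1,
                           outputs_info=[a0],
                           n_steps=k)

    f = theano.function([a0, k], powers, updates=updates)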
...

@@ -925,19 +1011,19 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        if not type(values) in (tuple, list):
            values = [values]
        # take out the updates of shared variable and build the dictionary
        # that tells what to update and with what value
        for val in update_map.keys():
            update_map[val] = values[update_map[val]]

        # Now we need to check the values returned
        # if it just one strip the list around it
        if n_outs == 1:
            # if we need to return just one step or several steps
            # note that when we return one step we have two cases, in
            # the first one store_steps is set to 1, case in which we don't
            # need to take a slice of the output (is already of the right
            # dimension) and case 2 when we store more then one step,
            # and we actually need to take a slice
            if return_steps.has_key(0):
                if return_steps[0] > 1:
...

@@ -969,11 +1055,11 @@ class Scan(Op):
    #
    def __init__(self, (inputs, outputs, givens, slice_to_seqs), n_seqs, n_outs,
                 inplace_map={}, seqs_taps={}, outs_taps={},
                 n_steps=gof.Constant(gof.generic, 'unknown', '?_steps'),
                 truncate_gradient=-1, n_outs_not_shared=0,
                 inner_fn_start_shared=0, inner_fn_end_shared=0,
                 go_backwards=False, store_steps={},
                 return_steps={}, mode=None, inplace=False, name=None):
        '''
        :param (inputs,outputs, givens,slice_to_seqs):
...

@@ -1014,7 +1100,7 @@ class Scan(Op):
        if inplace:
            for i in inplace_map.keys():
                # the n_steps is always the first argument of scan's perform,
                # so we need to shift everything by 1
                self.destroy_map.update({i: [inplace_map[i] + 1]})
            # make all inplace inputs mutable for the inner function for extra efficency
            for idx in xrange(len(inputs)):
...

@@ -1041,10 +1127,10 @@ class Scan(Op):
        self.inner_fn_start_shared = inner_fn_start_shared
        self.inner_fn_end_shared = inner_fn_end_shared
        self.outputs = outputs
        self.n_steps = n_steps  # It will be computed at runtime
        # This is here just for an optimization to be able to pick up if
        # scan is really needed in the graph; if the number of steps
        # scan does is a constant of 1, -1 or 0 then we can remove scan
        # from the graph
        self.mode = mode
        self.truncate_gradient = truncate_gradient
...

@@ -1346,8 +1432,8 @@ class Scan(Op):
        #update outputs
        for j in xrange(n_outs):
            if self.store_steps[j] < 1:
                # if you have provided no size for the missing output you might
                # find yourself here with a incorect array .. if that happens
                # realocate memory for the needed array
                try:
                    if hasattr(something[j], 'dtype') and (y[j].dtype != \
...

@@ -1393,13 +1479,13 @@ class Scan(Op):
        # make sure they are given as a list
        if not (type(scan_outputs) in (list, tuple)):
            scan_outputs = [scan_outputs]
        # get a list of clean inputs ( against which one can compute
        # gradients ) [ everything except shared variables with updates ]
        clean_inputs = self.inputs[:self.inner_fn_start_shared] + \
                       self.inputs[self.inner_fn_start_shared + \
                                   self.inner_fn_end_shared:]
        clean_inputs = [self.copy_map.get(x, x) for x in clean_inputs]
        s_inputs = [self.copy_map.get(x, x) for x in self.inputs]
        # function that computes the gradient (we sum over the gradients
...

@@ -1453,11 +1539,11 @@ class Scan(Op):
            if inner_gfn_outs[i] == None:
                inner_gfn_outs[i] = tensor.zeros_like(clean_inputs[i])
        for i in xrange(self.n_outs_not_shared):
            # Safety check
            if g_outs[i] == None:
                try:
                    # this try is for catching non ndarray inputs (random states)
                    # it is more of a safety check ( all random states should be
                    # after n_outs_not_shared ...
                    g_outs[i] = tensor.zeros_like(scan_outputs[i])
                except:
...

@@ -1473,9 +1559,9 @@ class Scan(Op):
                raise ValueError('Can not compute gradients if one does not ',
                                 'store all intermidiate results (remove store_steps'
                                 'from the dictionaries describing your outputs)')
        g_scan = ScanGrad((inner_gfn_ins, inner_gfn_outs),
                          self.n_seqs, self.n_outs, self.n_outs_not_shared,
                          self.go_backwards, self.seqs_taps, self.outs_taps,
                          truncate_gradient)
        g_scan_outs = g_scan(g_args)
        # We need to add several None's fpr shared vars with updates
...

@@ -1487,9 +1573,9 @@ class Scan(Op):
class ScanGrad(Op):
    """Gradient Op for Scan"""
    def __init__(self, (g_ins, g_outs), n_seqs, n_outs,
                 n_outs_not_shared,
                 go_backwards=False, seqs_taps={}, outs_taps={},
                 truncate_gradient=-1, mode=None, name=None):
        """
        :param mode: see scan fct
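``ScanGrad`` is the Op that backpropagation through time instantiates when a cost depends on the output of a ``scan``; ``truncate_gradient`` (documented in the docstring earlier in this diff) bounds how many steps that backward pass covers. The sketch below shows the user-facing side of that, again a hypothetical illustration under the same API assumptions as the earlier examples.

    import theano
    import theano.tensor as T
    from theano.scan import scan   # assumption, as in the earlier sketches

    x  = T.vector('x')
    y0 = T.scalar('y0')
    w  = theano.shared(0.5, name='w')

    # Simple recurrence y[t] = w * y[t-1] + x[t]; gradients through the scan
    # are truncated to the last 20 steps of the backward pass.
    y, updates = scan(lambda x_t, y_tm1, w: w * y_tm1 + x_t,
                      sequences=[x],
                      outputs_info=[y0],
                      non_sequences=[w],
                      truncate_gradient=20)

    cost = y[-1] ** 2
    g_w = T.grad(cost, w)          # this is where a ScanGrad node ends up in the graph
    f = theano.function([x, y0], g_w, updates=updates)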
...

@@ -1643,7 +1729,7 @@ class ScanGrad(Op):
                  for k in outInfo[:self.n_outs_not_shared]]
        g_non_seqs = [numpy.zeros_like(k) for k in non_seqs]
        # get gradient on the outputs
        g_outs = [arg.copy() for arg in args[1:self.n_outs_not_shared + 1]]
        # get the output of the scan operation
...

@@ -1776,8 +1862,8 @@ class ScanSpaceOptimizer(Optimizer):
                # look at all its clients
                for cl, _dx in out.clients:
                    if type(cl) == str:
                        # if the node is actually an output, then
                        # we need to store the entire thing
                        req_steps = None
                        break
                    else:
...

@@ -1788,12 +1874,12 @@ class ScanSpaceOptimizer(Optimizer):
                            req_steps = None
                            break
                        else:
                            # if it is a tensor, and the first
                            # dimension is just -1
                            if cl.op.idx_list[0] == -1 and req_steps != None:
                                req_steps = numpy.max([1, req_steps])
                            else:
                                # or a constant that evaluates to
                                # -1
                                try:
                                    idx = opt.get_constant_value(\
...

@@ -1810,23 +1896,23 @@ class ScanSpaceOptimizer(Optimizer):
                else:
                    store_steps[i] = op.store_steps[i]

            if numpy.any(store_steps != op.store_steps):
                new_scan = Scan((op.inputs, op.outputs, op.givens,
                                 op.slice_to_seqs), op.n_seqs, op.n_outs,
                                op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps,
                                op.truncate_gradient, op.n_outs_not_shared,
                                op.inner_fn_start_shared, op.inner_fn_end_shared,
                                op.go_backwards, store_steps, op.return_steps, op.mode,
                                op.inplace, name=op.fn.name).make_node(*node.inputs)
                # we not need to replace the outputs of scan
                for i, out in enumerate(node.outputs):
                    # if we are dealing with an output for which
                    # we changed the number of stored steps we
                    # also need to get rid off the subtensor
                    if op.store_steps[i] == 0 and store_steps[i] == 1:
                        # get the output of the subtensor variables
                        outSubTens = [x[0].outputs[0] for x in out.clients]
                        new_old = [(x, new_scan.outputs[i]) for x in outSubTens]
                        env.replace_all_validate(new_old, reason=
                                                 'scan_space_optimizer')
                    else:
                        env.replace_all_validate([(out,
...

@@ -1843,7 +1929,7 @@ def scan_make_inplace(node):
        return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs),
                    op.n_seqs, op.n_outs, op.inplace_map, op.seqs_taps,
                    op.outs_taps, op.n_steps,
                    op.truncate_gradient, op.n_outs_not_shared,
                    op.inner_fn_start_shared, op.inner_fn_end_shared,
                    op.go_backwards, op.store_steps, op.return_steps,
                    op.mode, inplace=True,
                    name=op.fn.name).make_node(*node.inputs).outputs
    return False
...