testgroup / pytensor / Commits / 029699ca

Commit 029699ca, authored Mar 03, 2010 by Pascal Lamblin

    Removed trailing spaces

Parent: c03c55c1

Showing 1 changed file with 154 additions and 154 deletions

theano/scan.py  +154  -154
"""This module provides the Scan Op
"""This module provides the Scan Op
Scanning is a general form of recurrence, which can be used for looping.
Scanning is a general form of recurrence, which can be used for looping.
The idea is that you *scan* a function along some input sequence, producing
The idea is that you *scan* a function along some input sequence, producing
an output at each time-step that can be seen (but not modified) by the
an output at each time-step that can be seen (but not modified) by the
function at the next time-step. (Technically, the function can see the
function at the next time-step. (Technically, the function can see the
previous K time-steps of your outputs and L time steps (from the past and
previous K time-steps of your outputs and L time steps (from the past and
future of the sequence) of your inputs.
future of the sequence) of your inputs.
So for example, ``sum()`` could be computed by scanning the ``z+x_i``
So for example, ``sum()`` could be computed by scanning the ``z+x_i``
function over a list, given an initial state of ``z=0``.
function over a list, given an initial state of ``z=0``.
Special cases:
Special cases:
* A *reduce* operation can be performed by returning only the last
* A *reduce* operation can be performed by returning only the last
output of a ``scan``.
output of a ``scan``.
* A *map* operation can be performed by applying a function that
* A *map* operation can be performed by applying a function that
ignores each previous output.
ignores each previous output.
Often a for-loop can be expressed as a ``scan()`` operation, and ``scan`` is
Often a for-loop can be expressed as a ``scan()`` operation, and ``scan`` is
the closest that theano comes to looping. The advantage of using ``scan``
the closest that theano comes to looping. The advantage of using ``scan``
over for loops is that it allows the number of iterations to be a part of the symbolic graph.
over for loops is that it allows the number of iterations to be a part of the symbolic graph.
The Scan Op should typically be used by calling the ``scan()`` function.
The Scan Op should typically be used by calling the ``scan()`` function.
"""
"""
__docformat__ = 'restructedtext en'

import theano
...
@@ -30,7 +30,7 @@ from theano.tensor import opt, TensorType
from theano import gof, Apply
from theano.compile import optdb
import theano.tensor.shared_randomstreams as shared_random
import copy
import numpy
...
@@ -66,11 +66,11 @@ def hash_listsDictsTuples(x):
###################################
## Implement specific function calls : map, reduce, generate

def map(fn, sequences, non_sequences=[], n_steps=0,
        truncate_gradient=-1, go_backwards=False,
        mode='FAST_RUN'):
    return scan(fn, sequences=sequences, outputs_info=[], non_sequences=non_sequences,
                truncate_gradient=truncate_gradient,
                go_backwards=go_backwards, mode=mode)
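The *map* special case can be sketched the same way in plain Python (illustrative only, not the Theano code path): the step function simply ignores any previous output.

```python
def py_map(fn, sequence):
    """map as a degenerate scan: fn takes only the current slice."""
    outputs = []
    for x_i in sequence:
        outputs.append(fn(x_i))  # no previous output is consulted
    return outputs

squares = py_map(lambda x: x * x, [1, 2, 3])
# squares == [1, 4, 9]
```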
...
@@ -88,28 +88,28 @@ def map(fn, sequences, non_sequences = [], n_steps =0,
# z - a sequence that we need two previous values of
# and we want z to be computed inplace using the storage of 'a'.
#
# scan(fn, [dict(input=a, taps=[-1,0,1])],
#          [dict(initial=x_init, taps=[-1], ????????),
#           None
#           dict(initial=z_init, taps=[-2,-1], inplace=a,)])
#
# QUESTION:
# The larger (in absolute value) the sequence taps, the shorter the output,
# right? If sequence_taps = {0: [-10, 10]} and I pass an input with 22
# rows, then the scan will output something of length <= 2, right?
#
# ANSWER:
# Yes, actually it will be exactly 2 (if there are no other constraints).

def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
         n_steps=0, truncate_gradient=-1, go_backwards=False,
         mode=None):
    '''Function that constructs and applies a Scan op

    :param fn:
        Function that describes the operations involved in one step of scan.
        Given variables representing all the slices of input and past values of
        outputs and other non sequences parameters, ``fn`` should produce
...
@@ -117,17 +117,17 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        which the arguments to this function are given is very important. You
        should have the following order:

        * all time slices of the first sequence (as given in the
          ``sequences`` list), ordered in the same fashion as the time taps provided
        * all time slices of the second sequence (as given in the
          ``sequences`` list), ordered in the same fashion as the time taps provided
        * ...
        * all time slices of the first output (as given in the
          ``initial_state`` list), ordered in the same fashion as the time taps provided
        * all time slices of the second output (as given in the
          ``initial_state`` list), ordered in the same fashion as the time taps provided
        * ...
        * all other parameters over which scan doesn't iterate, given
          in the same order as in ``non_sequences``. If you are using shared
          variables over which you do not want to iterate, you do not need to
...
@@ -135,51 +135,51 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        function should return the outputs after each step plus the updates for
        any of the shared variables. You can return only outputs or only
        updates. If you have both outputs and updates, the function should return
        them as a tuple: (outputs, updates) or (updates, outputs).

        Outputs can be just a theano expression if you have only one output, or
        a list of theano expressions. Updates can be given either as a list of tuples or
        as a dictionary. If you have a list of outputs, their order
        should match that of their ``initial_states``.
    :param sequences:
        list of Theano variables or dictionaries containing Theano variables over which
        scan needs to iterate. The reason you might want to wrap a certain Theano
        variable in a dictionary is to provide auxiliary information about how to iterate
        over that variable. For example, this is how you specify that you want to use
        several time slices of this sequence at each iteration step. The dictionary
        should have the following keys:

        * ``input`` -- Theano variable representing the sequence
        * ``taps`` -- temporal taps to use for this sequence. They are given as a list
          of ints, where a value ``k`` means that at iteration step ``t`` scan needs to
          also provide the slice ``t+k``. The order in which you provide these int values
          here is the same order in which the slices will be provided to ``fn``.

        If you do not wrap a variable in a dictionary, scan will do it for you, under
        the assumption that you use only one slice, defined as a tap of offset 0. This
        means that at step ``t`` scan will provide the slice at position ``t``.
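For intuition, here is a plain-Python sketch of which slices a taps list selects at step ``t`` (the helper name ``slices_for_step`` is made up for illustration):

```python
def slices_for_step(sequence, taps, t):
    """Return the slices a taps list selects at iteration step t."""
    return [sequence[t + k] for k in taps]

a = [10, 20, 30, 40, 50]
# with taps [-1, 0, 1], step t = 2 sees a[1], a[2], a[3]
three_slices = slices_for_step(a, [-1, 0, 1], 2)   # [20, 30, 40]
# the default wrapping (taps = [0]) provides just a[t]
default_slice = slices_for_step(a, [0], 2)         # [30]
```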
    :param outputs_info:
        list of Theano variables or dictionaries containing Theano variables used
        to initialize the outputs of scan. As before (for ``sequences``), the reason
        you would wrap a Theano variable in a dictionary is to provide additional
        information about how scan should deal with that specific output. The dictionary
        should contain the following keys:

        * ``initial`` -- Theano variable containing the initial state of the output
        * ``taps`` -- temporal taps to use for this output. The taps are given as a
          list of ints (only negative, since you can not use future values of outputs),
          with the same meaning as for ``sequences`` (see above).
        * ``inplace`` -- theano variable pointing to one of the input sequences; this
          flag tells scan that the output should be computed in the memory space occupied
          by that input sequence. Note that scan will only do this if allowed by the
          rest of your computational graph.

        If the function applied recursively uses only the
        previous value of the output, the initial state should have the
        same shape as one time step of the output; otherwise, the initial state
        should have the same number of dimensions as the output. This is easily
        understood through an example. For computing ``y[t]`` let us assume that we
        need ``y[t-1]``, ``y[t-2]`` and ``y[t-4]``. Through an abuse of
        notation, when ``t = 0``, we would need values for ``y[-1]``, ``y[-2]``
...
@@ -189,28 +189,28 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        case is 4. If ``init_y`` is the variable containing the initial state
        of ``y``, then ``init_y[0]`` corresponds to ``y[-4]``, ``init_y[1]``
        corresponds to ``y[-3]``, ``init_y[2]`` corresponds to ``y[-2]``, and
        ``init_y[3]`` corresponds to ``y[-1]``. The default behaviour of scan is
        the following:

        * if you do not wrap an output in a dictionary, scan will wrap it for you,
          assuming that you use only the last step of the output (i.e. it makes your tap
          value list equal to [-1]) and that it is not computed inplace
        * if you wrap an output in a dictionary and you do not provide any taps but
          you do provide an initial state, it will assume that you are using only a tap value
          of -1
        * if you wrap an output in a dictionary but you do not provide any initial state,
          it assumes that you are not using any form of taps
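The ``init_y`` indexing convention described above can be checked with a small plain-Python sketch (illustrative only):

```python
taps = [-1, -2, -4]
k = -min(taps)                       # deepest tap: init_y must hold 4 steps
# init_y[j] corresponds to y[j - k]
mapping = [j - k for j in range(k)]  # [-4, -3, -2, -1]
# so init_y[0] is y[-4] and init_y[3] is y[-1], as stated above
```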
    :param non_sequences:
        Parameters over which scan should not iterate. These parameters are
        given at each time step to the function applied recursively.

    :param n_steps:
        Number of steps to iterate. If this value is provided, scan will run only for
        this number of steps (given that the input sequences are sufficiently long).
        If there is no input sequence (for example in the case of a generator network),
        scan will iterate for this number of steps. It can be a theano scalar or a number.

    :param truncate_gradient:
        Number of steps to use in truncated BPTT. If you compute gradients
...
@@ -222,33 +222,33 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    :param go_backwards:
        Flag indicating whether scan should go backwards through the sequences

    :rtype: tuple
    :return: tuple of the form (outputs, updates); ``outputs`` is either a
             Theano variable or a list of Theano variables representing the
             outputs of scan. ``updates`` is a dictionary specifying the
             update rules for all shared variables used in the scan
             operation; this dictionary should be passed to ``theano.function``.
    '''
    # check if inputs are just single variables instead of lists
    if not (type(sequences) in (list, tuple)):
        seqs = [sequences]
    else:
        seqs = sequences

    print outputs_info
    if not (type(outputs_info) in (list, tuple)):
        outs_info = [outputs_info]
    else:
        outs_info = outputs_info

    if not (type(non_sequences) in (list, tuple)):
        non_seqs = [non_sequences]
    else:
        non_seqs = non_sequences

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)
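The normalization performed by the three ``if`` blocks above is the usual wrap-single-value-in-a-list idiom; a standalone sketch (the helper ``as_list`` is illustrative, not part of the module):

```python
def as_list(x):
    """Mirror scan()'s normalization of sequences/outputs_info/non_sequences."""
    if not (type(x) in (list, tuple)):
        return [x]
    return x

wrapped = as_list('x')          # ['x']      -- a single variable gets wrapped
untouched = as_list(['a', 'b']) # ['a', 'b'] -- lists pass through unchanged
```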
...
@@ -258,7 +258,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # wrap sequences in a dictionary if they are not already;
    # in the same pass create a sequences_taps dictionary
    for i in xrange(n_seqs):
        if not type(seqs[i]) == dict:
            seqs[i] = dict(input=seqs[i], taps=[0])
        # see if taps values are provided as a list
        elif seqs[i].get('taps', None):
...
@@ -278,7 +278,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        if outs_info[i]:
            if not type(outs_info[i]) == dict:
                outs_info[i] = dict(initial=outs_info[i], taps=[-1])
            # if there is no initial state but there are taps
            elif (not outs_info[i].get('initial', None)) and \
                    (outs_info[i].get('taps', None)):
                raise ValueError('If you are using slices of an output you need to '
                                 'provide an initial state for it', outs_info[i])
...
@@ -303,13 +303,13 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
            raise ValueError('Asked to compute in place of a non-input variable',
                             outs_info[i].get('inplace', None))

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we use this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new list of possible inputs
    args = []
    dummy_notshared_ins = 0
    dummy_notshared_init_outs = 0
    slice_to_seqs = []
    # go through sequences picking up time slices as needed
...
@@ -344,7 +344,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    # add only the not shared variables to the arguments of the dummy
    # function [ a function should not get shared variables as input ]
    dummy_args = args + notshared_other_args
    # arguments for the lambda expression that gives us the output
    # of the inner function
    args += non_seqs
...
@@ -397,7 +397,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
    shared_non_seqs = []
    givens = {}

    # if the number of outputs of the function does not match the number of
    # assumed outputs
    if len(inner_fn_out_states) != n_outs:
...
@@ -406,9 +406,9 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        # are required to have any sort of time taps;
        # we just need to update the number of actual outputs
        n_outs = len(inner_fn_out_states)
        # other updates :
        for i in xrange(n_outs):
            outs_info += [dict()]
    else:
        raise ValueError('There has been a terrible mistake in our input arguments'
                         ' and scan is totally lost. Make sure that you indicate for every '
...
@@ -426,13 +426,13 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
        if isinstance(input.variable, theano.compile.SharedVariable) and input.update:
            new_var = input.variable.type()
            inner_fn_inputs.append(new_var)
            val = slice_to_seqs[-1] if slice_to_seqs else -1
            slice_to_seqs += [val + 1]
            inner_fn_out_states += [input.update]
            update_map[input.variable] = n_extended_outs
            outputs_taps[n_extended_outs] = [-1]
            n_extended_outs += 1
            store_steps += [1]
            shared_outs += [input.variable]
            givens[input.variable] = inner_fn_inputs[-1]
...
@@ -446,34 +446,34 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
            givens[input.variable] = inner_fn_inputs[-1]
        elif not isinstance(input.variable, theano.compile.SharedVariable):
            inner_fn_inputs.append(input.variable)
    # Create the Scan op object
    local_op = Scan((inner_fn_inputs, inner_fn_out_states, givens, slice_to_seqs),
                    n_seqs, n_extended_outs, inplace_map,
                    sequences_taps, outputs_taps, truncate_gradient,
                    go_backwards, store_steps, mode)
    # Call the object on the input sequences, initial values for outs,
    # and non sequences
    for seq in seqs:
        if not seq.get('input', None):
            raise ValueError('All input sequences should provide')
    unwrapped_seqs = [seq.get('input', theano.tensor.as_tensor(0.)) for seq in seqs]
    unwrapped_outs = [out.get('initial', theano.tensor.as_tensor(0.)) for out in outs_info]

    values = local_op(*([theano.tensor.as_tensor(n_steps)]
                        + unwrapped_seqs
                        + unwrapped_outs
                        + shared_outs
                        + notshared_other_args
                        + shared_non_seqs))

    if not type(values) in (tuple, list):
        values = [values]
    for val in update_map.keys():
        update_map[val] = values[update_map[val]]

    if n_outs == 1:
        values = values[0]
    else:
        values = values[:n_outs]

    return (values, update_map)
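The ``update_map`` loop above replaces output *indices* with the output values themselves; a plain-Python sketch of that substitution (the names here are made up for illustration):

```python
# before the loop: shared variable -> index of its new value in ``values``
values = ['out0', 'new_w', 'new_v']
update_map = {'w': 1, 'v': 2}
for var in update_map.keys():
    update_map[var] = values[update_map[var]]
# after the loop: shared variable -> its updated value
# update_map == {'w': 'new_w', 'v': 'new_v'}
```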
...
@@ -493,22 +493,22 @@ class Scan(theano.Op):
                 mode='FAST_RUN', inplace=False):
        '''
        :param (inputs, outputs, givens, slice_to_seqs):
            inputs and outputs are Theano variables that describe the function that is
            applied recursively; the givens list is used to replace shared
            variables with non-shared ones; slice_to_seqs is a convenience list that
            tells which of the inputs is a slice of which of the sequences
        :param n_seqs: number of sequences over which scan will have to
            iterate
        :param n_outs: number of outputs of the scan op
        :param inplace_map: see scan function above
        :param seqs_taps: see scan function above
        :param outs_taps: see scan function above
        :param truncate_gradient: number of steps after which scan should
            truncate; -1 implies no truncation
        :param go_backwards: see scan function above
        :param store_steps:
            a list of the same size as the number of outputs; the value at position
            ``i`` in the list corresponds to the ``i``-th output, and it tells how many
            steps (from the end towards the beginning) of the output you really need and should
            return; given this information, scan can know (if possible) to allocate only
            the amount of memory needed to compute that many entries
...
@@ -599,7 +599,7 @@ class Scan(theano.Op):
               (self.n_outs == other.n_outs) and \
               (self.n_args == other.n_args)
        return rval

    def __hash__(self):
        # the self.apply_output_types are a function of all these things
...
@@ -624,23 +624,23 @@ class Scan(theano.Op):
...
@@ -624,23 +624,23 @@ class Scan(theano.Op):
        The args are packed like this:
            n_steps
            X sequence inputs x_1, x_2, ... x_<self.n_seqs>
            Y initial states (u_1, u_2, ... u_<self.n_outs>) for our outputs. Each must have appropriate length (T_1, T_2, ..., T_Y).
            W other inputs w_1, w_2, ... w_W

        There are at least 1 + self.n_seqs + self.n_outs inputs, and the ones above this number
        are passed to the scanned function as non-sequential inputs.

        The outputs are more straightforward:
            Y sequence outputs y_1, y_2, ... y_<self.n_outs>
        """
        n_steps = 0
        if (self.n_seqs == 0) and (args[0] == 0):
            raise ValueError('Scan does not know over how many steps it '
                'should iterate! No input sequence or number of steps to '
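The flat argument layout described in the docstring above can be sketched with a plain Python list; all names, shapes, and values below are invented for illustration, not taken from the module:

```python
# Hypothetical packing of perform()'s flat argument list:
# n_steps, then X sequence inputs, then Y initial states, then the rest.
n_steps = 5
x_1 = [0, 1, 2, 3, 4]     # one sequence input   (X = 1)
u_1 = [0.0]               # one initial state    (Y = 1)
w_1 = 2.0                 # one non-sequential input

args = [n_steps, x_1, u_1, w_1]

n_seqs, n_outs = 1, 1
# There are at least 1 + n_seqs + n_outs inputs; everything past that
# index is passed to the scanned function as a non-sequential input.
assert len(args) >= 1 + n_seqs + n_outs
non_seqs = args[1 + n_seqs + n_outs:]
```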
...
@@ -648,31 +648,31 @@ class Scan(theano.Op):
        if (args[0] != 0):
            n_steps = args[0]
        for i in xrange(self.n_seqs):
            if self.seqs_taps.has_key(i):
                # compute actual length of the sequence (we need to see what
                # past taps this sequence has, and leave room for them)
                seq_len = args[i+1].shape[0] + min(self.seqs_taps[i])
                if max(self.seqs_taps[i]) > 0:
                    # using future values, so need to end the sequence earlier
                    seq_len -= max(self.seqs_taps[i])
                if n_steps == 0:
                    # length of the sequences, leaving room for the largest
                    n_steps = seq_len
                if seq_len != n_steps:
                    warning(('Input sequence %d has a shorter length than the '
                        'expected number of steps %d') % (i, n_steps))
                    n_steps = min(seq_len, n_steps)

        # check if we deal with an inplace operation
        inplace_map = self.inplace_map
        if not self.inplace:  # if it was not optimized to work inplace
            inplace_map = {}

        # check lengths of init_outs
        for i in xrange(self.n_seqs+1, self.n_seqs+self.n_outs+1):
            if self.outs_taps.has_key(i-self.n_seqs-1):
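The effective-length computation above (leave room for the deepest past tap, stop early for future taps) can be sketched in isolation; the helper name and values are invented:

```python
# Sketch of how taps shrink the number of usable steps of a sequence:
# a negative tap reads the past, a positive tap reads the future.
def effective_length(seq_len, taps):
    n = seq_len + min(taps)   # leave room for the deepest past tap
    if max(taps) > 0:
        n -= max(taps)        # future taps force an earlier stop
    return n

past_only = effective_length(10, [-2, 0])    # 10 - 2 = 8 usable steps
with_future = effective_length(10, [-2, 1])  # 10 - 2 - 1 = 7 usable steps
```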
...
@@ -682,10 +682,10 @@ class Scan(theano.Op):
                warning(('Initial state for output %d has fewer values than '
                    'required by the maximal past value %d. Scan will use 0s'
                    ' for missing values') % (i-self.n_iterable-1, req_size))

        self.n_steps = n_steps
        y = self.scan(self.fn, args[1:], self.n_seqs, self.n_outs,
                      self.seqs_taps, self.outs_taps, n_steps,
                      self.go_backwards, inplace_map)
        '''
...
@@ -698,22 +698,22 @@ class Scan(theano.Op):
            outs[i][0] = y[i]
        '''
        for i in xrange(self.n_outs):
            if self.store_steps[i] > 1:
                # we need to reorder the steps .. to have them in the correct
                # order; we use numpy advanced indexing for this
                # index order:
                index_order = range(self.idx_store_steps[i], self.store_steps[i]) + \
                              range(self.idx_store_steps[i])
                outs[i][0] = y[i][index_order]
            else:
                outs[i][0] = y[i]

    def scan(self, fn, args, n_seqs, n_outs, seqs_taps, outs_taps, n_steps,
             go_backwards, inplace_map):
        ''' Actual loop of the scan op perform function '''
        # Note that we removed n_steps from the args of this function, so the
        # order of arguments is slightly different compared to perform
        y = []
        # When you have taps, you need to leave borders in your sequences and
        # initial outputs for those taps; here we compute what those borders
        # are for sequences
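The `index_order` reordering above treats the stored output as a circular buffer whose oldest entry sits at `idx_store_steps`; one advanced-indexing pass restores chronological order. A minimal sketch with invented values:

```python
import numpy

store_steps = 4
idx = 2                             # position of the oldest entry
y = numpy.array([30, 40, 10, 20])   # steps written in circular order

# concatenate [idx, store_steps) with [0, idx) to unroll the buffer
index_order = list(range(idx, store_steps)) + list(range(idx))
ordered = y[index_order]            # numpy advanced (integer-array) indexing
```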
...
@@ -724,7 +724,7 @@ class Scan(theano.Op):
        # create storage space for the outputs (using corresponding inputs if
        # we are dealing with inplace operations)
        # `idx_store_steps` is a dictionary telling us the current position in
        # y of an output where we want to store only the last k steps
...
@@ -735,7 +7,35 @@ class Scan(theano.Op):
                    seqs_taps[inplace_map[i]] >= 0:
                y += [args[inplace_map[i]][:n_steps]]
            else:
                # check if you are using past values .. throw a warning and do
                # not work inplace
                if inplace_map.has_key(i) and seqs_taps.has_key(inplace_map[i]) and seqs_taps[inplace_map[i]] < 0:
                    warning('Can not work inplace because of past values')
...
@@ -773,7 +773,7 @@ class Scan(theano.Op):
            _i = i
            if go_backwards:
                _i = n_steps - 1 - i
            # collect data from sequences
            for j in xrange(n_seqs):
                # get borders
                if seqs_taps.has_key(j):
...
@@ -793,13 +793,13 @@ class Scan(theano.Op):
                    for tap_value in ls_taps:
                        if i + tap_value < 0:
                            if sz < 1:
                                # this is a special case, when our initial
                                # state has no temporal dimension
                                fn_args += [args[j+n_seqs]]
                            else:
                                k = i + sz + tap_value
                                if k < 0:
                                    # past value not provided .. issue a
                                    # warning and use 0s of the correct dtype
                                    fn_args += [numpy.zeros(args[j+n_seqs][0].shape,
                                        dtype=args[j+n_seqs][0].dtype)]
...
@@ -815,8 +815,8 @@ class Scan(theano.Op):
                            # just the last one
                            fn_args += [y[j]]
                        else:
                            # storing only the last k
                            # get what idx we want
                            req_idx = (self.idx_store_steps[j] + tap_value +
                                       self.store_steps[j])
                            # we need this modulo self.store_steps[j]
                            req_idx = req_idx % self.store_steps[j]
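The wrap-around index computed above can be shown with plain integers; the variable values are invented for illustration:

```python
# Circular-buffer lookup: with only the last `store_steps` values kept,
# the slot for a (negative) tap is the write position offset by the tap,
# shifted by the buffer size to stay non-negative, then wrapped modulo
# the buffer size.
store_steps = 3
idx_store_steps = 1    # next write position in the circular buffer
tap_value = -2         # we want the value from two steps ago

req_idx = (idx_store_steps + tap_value + store_steps) % store_steps
```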
...
@@ -832,7 +832,7 @@ class Scan(theano.Op):
            # if you have provided no size for the missing output, you might
            # find yourself here with an incorrect array .. if that happens,
            # reallocate memory for the needed array
            try:
                if hasattr(something[j], 'dtype') and (y[j].dtype != something[j].dtype):
                    raise ValueError('wrong dtype')
...
@@ -864,24 +864,24 @@ class Scan(theano.Op):
        return y

    def grad(self, args, g_outs):
        raise NotImplementedError('This will be implemented in the near future')
        '''
        if True:
        #((self.updates.keys() != []) or (self.inplace_map.keys() != []) \
        #        or numpy.any(self.store_steps)):
            # warning('Can not compute gradients if inplace or updates ' \
            #         'are used or if you do not keep past value of outputs.' \
            #         'Use force_gradient if you know for sure ' \
            #         'that the gradient can be computed automatically.')
            warning('Gradient not fully tested yet !')
            return [None for i in args]
        else:
            # forward pass
            y = self(*args)
            if not (type(y) in (list, tuple)):
                y = [y]
            g_y = [outputs[0].type()]

            def compute_gradient(y, g_y):
...
@@ -893,11 +893,11 @@ class Scan(theano.Op):
                        theano._asarray(0, dtype=p.type.dtype))
                return [gmap.get(p, zero(p)) for p in inputs]
            i = 0
            while
            g_args = compute_gradient(outputs[0], g_y[-1])
            # for all outputs compute gradients and then sum them up
            for y in outputs[1:]:
                g_y += [y.type()]
...
@@ -905,8 +905,8 @@ class Scan(theano.Op):
            for i in xrange(len(g_args)):
                g_args[i] += g_args_y[i]
            self.g_ins = g_y + inputs
            self.g_outs = g_args
...
@@ -915,13 +915,13 @@ class Scan(theano.Op):
            if g_outs[i] == None:
                g_outs[i] = theano.tensor.zeros_like(y[i])
        g_args = [self.n_steps] + g_outs + y
        # check if go_backwards is true
        if self.go_backwards:
            for seq in args[1:self.n_seqs]:
                g_args += [seq[::-1]]
        else:
            g_args += args[1:self.n_seqs]
        g_args += args[1+self.n_seqs:]
...
@@ -938,13 +938,13 @@ class Scan(theano.Op):
def scan_make_inplace(node):
    op = node.op
    if isinstance(op, Scan) and (not op.inplace) and (op.inplace_map.keys() != []):
        return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs),
                    op.n_seqs, op.n_outs, op.inplace_map, op.seqs_taps,
                    op.outs_taps, op.truncate_gradient, op.go_backwards,
                    op.store_steps, inplace=True).make_node(*node.inputs).outputs
    return False

optdb.register('scanOp_make_inplace', opt.in2out(scan_make_inplace,
               ignore_newtrees=True), 75, 'fast_run', 'inplace')
...
@@ -953,7 +953,7 @@ optdb.register('scanOp_make_inplace', opt.in2out(scan_make_inplace,
'''

class ScanGrad(theano.Op):
    """Gradient Op for Scan"""
    def __init__(self, (g_ins, g_outs), n_seqs, n_outs,
                 seqs_taps={}, outs_taps={}, truncate_gradient=-1):
        self.grad_fn = theano.function(g_ins, g_outs)
        self.inputs = g_ins
...
@@ -966,7 +966,7 @@ class ScanGrad(theano.Op):
        self.destroy_map = {}

    def __eq__(self, other):
        rval = type(self) == type(other)
        if rval:
            rval = (self.inputs == other.inputs) and \
...
@@ -975,7 +975,7 @@ class ScanGrad(theano.Op):
                   (self.n_outs == other.n_outs) and \
                   (self.truncate_gradient == other.truncate_gradient) and \
                   (self.seqs_taps == other.seqs_taps) and \
                   (self.outs_taps == other.outs_taps)
        return rval

    def __hash__(self):
...
@@ -989,10 +989,10 @@ class ScanGrad(theano.Op):
               hash_dict(self.outs_taps)

    def make_node(self, *args):
        # input of the gradient op:
        #    | g_outs | y      | seqs   | outs   | non_seqs |
        #    | n_outs | n_outs | n_seqs | n_outs | unknown  |
        # return
        #    | grad of seqs | grad of outs | grad of non_seqs |
        #    | n_seqs       | n_outs       | unknown          |
        return theano.Apply(self, list(args),
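The input layout in the comment above can be illustrated by packing and unpacking a flat list; the element names and counts below are invented placeholders, not values from the module:

```python
# Hypothetical flat layout: | g_outs | y | seqs | outs | non_seqs |
n_seqs, n_outs = 2, 1
g_outs   = ['g_y0']        # n_outs entries
y        = ['y0']          # n_outs entries
seqs     = ['x0', 'x1']    # n_seqs entries
outs     = ['out0']        # n_outs entries
non_seqs = ['w0']          # unknown count, everything remaining

args = g_outs + y + seqs + outs + non_seqs

# slicing the flat list back apart, mirroring the table
assert args[:n_outs] == g_outs
assert args[n_outs:2*n_outs] == y
assert args[2*n_outs:2*n_outs + n_seqs] == seqs
```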
...
@@ -1005,8 +1005,8 @@ class ScanGrad(theano.Op):
        seqs = inputs[:self.n_seqs]
        seeds = inputs[self.n_seqs:self.n_seqs+self.n_outs]
        non_seqs = inputs[self.n_outs+self.n_seqs:]
        # generate space for gradient
        g_seqs = [numpy.zeros_like(k) for k in seqs]
        g_seeds = [numpy.zeros_like(k) for k in seeds]
        g_non_seqs = [numpy.zeros_like(k) for k in non_seqs]
...
@@ -1045,7 +1045,7 @@ class ScanGrad(theano.Op):
            _ins = []
            for j in xrange(self.n_seqs):
                if self.seqs_taps.has_key(j):
                    ls_taps = self.seqs_taps[j]
                    min_tap = seqs_mins[j]
                    for tap_value in ls_taps:
                        k = i - min_tap + tap_value
...
@@ -1073,8 +1073,8 @@ class ScanGrad(theano.Op):
            g_out = [arg[i] for arg in g_outs]
            grad_args = g_out + _ins + _outs + non_seqs
            grads = self.grad_fn(*grad_args)
            # get gradient for inputs
            pos = 0
            for j in xrange(self.n_seqs):
                if self.seqs_taps.has_key(j):
...