Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f4b29ab6
提交
f4b29ab6
authored
10月 28, 2011
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #166 from pascanur/scan_grad_nsteps
Fix a bug of grad reported by Michael Forbes
上级
5dd4c64a
2abc014b
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
274 行增加
和
261 行删除
+274
-261
scan.py
theano/scan_module/scan.py
+254
-261
test_scan.py
theano/scan_module/tests/test_scan.py
+20
-0
没有找到文件。
theano/scan_module/scan.py
浏览文件 @
f4b29ab6
...
@@ -34,10 +34,10 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
...
@@ -34,10 +34,10 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
``foldr()``.
``foldr()``.
"""
"""
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"Frederic Bastien "
"James Bergstra "
"James Bergstra "
"Pascal Lamblin "
)
"Pascal Lamblin "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
@@ -63,16 +63,16 @@ from scan_utils import safe_new, traverse
...
@@ -63,16 +63,16 @@ from scan_utils import safe_new, traverse
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan'
)
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan'
)
def
scan
(
fn
def
scan
(
fn
,
,
sequences
=
None
sequences
=
None
,
,
outputs_info
=
None
outputs_info
=
None
,
,
non_sequences
=
None
non_sequences
=
None
,
,
n_steps
=
None
n_steps
=
None
,
,
truncate_gradient
=
-
1
truncate_gradient
=-
1
,
,
go_backwards
=
False
go_backwards
=
False
,
,
mode
=
None
mode
=
None
,
,
name
=
None
name
=
None
,
,
profile
=
False
):
profile
=
False
):
"""
"""
This function constructs and applies a Scan op to the provided
This function constructs and applies a Scan op to the provided
arguments.
arguments.
...
@@ -333,12 +333,12 @@ def scan( fn
...
@@ -333,12 +333,12 @@ def scan( fn
'''
'''
if
x
is
None
:
if
x
is
None
:
return
[]
return
[]
elif
not
isinstance
(
x
,
(
list
,
tuple
)):
elif
not
isinstance
(
x
,
(
list
,
tuple
)):
return
[
x
]
return
[
x
]
else
:
else
:
return
list
(
x
)
return
list
(
x
)
seqs
=
wrap_into_list
(
sequences
)
seqs
=
wrap_into_list
(
sequences
)
outs_info
=
wrap_into_list
(
outputs_info
)
outs_info
=
wrap_into_list
(
outputs_info
)
# Make sure we get rid of numpy arrays or ints or anything like that
# Make sure we get rid of numpy arrays or ints or anything like that
...
@@ -356,19 +356,19 @@ def scan( fn
...
@@ -356,19 +356,19 @@ def scan( fn
# To do that we check here to see the nature of n_steps
# To do that we check here to see the nature of n_steps
n_fixed_steps
=
None
n_fixed_steps
=
None
if
isinstance
(
n_steps
,
(
float
,
int
)):
if
isinstance
(
n_steps
,
(
float
,
int
)):
n_fixed_steps
=
int
(
n_steps
)
n_fixed_steps
=
int
(
n_steps
)
else
:
else
:
try
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
):
except
(
TypeError
,
AttributeError
):
n_fixed_steps
=
None
n_fixed_steps
=
None
# Check n_steps is an int
# Check n_steps is an int
if
(
hasattr
(
n_steps
,
'dtype'
)
and
if
(
hasattr
(
n_steps
,
'dtype'
)
and
str
(
n_steps
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
):
str
(
n_steps
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
):
raise
ValueError
(
' n_steps must be an int. dtype provided '
raise
ValueError
(
' n_steps must be an int. dtype provided '
'is
%
s'
%
n_steps
.
dtype
)
'is
%
s'
%
n_steps
.
dtype
)
# compute number of sequences and number of outputs
# compute number of sequences and number of outputs
n_seqs
=
len
(
seqs
)
n_seqs
=
len
(
seqs
)
...
@@ -377,11 +377,11 @@ def scan( fn
...
@@ -377,11 +377,11 @@ def scan( fn
return_steps
=
{}
return_steps
=
{}
# wrap sequences in a dictionary if they are not already dictionaries
# wrap sequences in a dictionary if they are not already dictionaries
for
i
in
xrange
(
n_seqs
):
for
i
in
xrange
(
n_seqs
):
if
not
isinstance
(
seqs
[
i
],
dict
)
:
if
not
isinstance
(
seqs
[
i
],
dict
):
seqs
[
i
]
=
dict
(
input
=
seqs
[
i
],
taps
=
[
0
])
seqs
[
i
]
=
dict
(
input
=
seqs
[
i
],
taps
=
[
0
])
elif
seqs
[
i
]
.
get
(
'taps'
,
None
):
elif
seqs
[
i
]
.
get
(
'taps'
,
None
):
seqs
[
i
][
'taps'
]
=
wrap_into_list
(
seqs
[
i
][
'taps'
])
seqs
[
i
][
'taps'
]
=
wrap_into_list
(
seqs
[
i
][
'taps'
])
elif
seqs
[
i
]
.
get
(
'taps'
,
True
)
is
None
:
elif
seqs
[
i
]
.
get
(
'taps'
,
True
)
is
None
:
# seqs dictionary does not have the ``taps`` key
# seqs dictionary does not have the ``taps`` key
seqs
[
i
][
'taps'
]
=
[
0
]
seqs
[
i
][
'taps'
]
=
[
0
]
...
@@ -391,30 +391,31 @@ def scan( fn
...
@@ -391,30 +391,31 @@ def scan( fn
if
isinstance
(
outs_info
[
i
],
dict
):
if
isinstance
(
outs_info
[
i
],
dict
):
# DEPRICATED :
# DEPRICATED :
if
outs_info
[
i
]
.
get
(
'return_steps'
,
None
):
if
outs_info
[
i
]
.
get
(
'return_steps'
,
None
):
_logger
.
warning
(
(
"Using `return_steps` has been depricated."
_logger
.
warning
((
"Using `return_steps` has been "
" Simply select the entries you need using "
"depricated. Simply select the entries you "
" a subtensor. Scan will optimize memory "
"need using a subtensor. Scan will optimize "
" consumption, so do not worry about that."
))
"memory consumption, so do not worry about "
"that."
))
return_steps
[
i
]
=
outs_info
[
i
][
'return_steps'
]
return_steps
[
i
]
=
outs_info
[
i
][
'return_steps'
]
# END
# END
if
not
isinstance
(
outs_info
[
i
],
dict
):
if
not
isinstance
(
outs_info
[
i
],
dict
):
# by default any output has a tap value of -1
# by default any output has a tap value of -1
outs_info
[
i
]
=
dict
(
initial
=
outs_info
[
i
],
taps
=
[
-
1
])
outs_info
[
i
]
=
dict
(
initial
=
outs_info
[
i
],
taps
=
[
-
1
])
elif
(
not
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
elif
(
not
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
# ^ no initial state but taps provided
# ^ no initial state but taps provided
raise
ValueError
(
(
'If you are using slices of an output '
raise
ValueError
(
(
'If you are using slices of an output '
'you need to provide a initial state '
'you need to provide a initial state '
'for it'
),
outs_info
[
i
]
)
'for it'
),
outs_info
[
i
]
)
elif
(
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
elif
(
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
not
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
not
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
# ^ initial state but taps not provided
# ^ initial state but taps not provided
if
outs_info
[
i
]
.
has_key
(
'taps'
)
:
if
'taps'
in
outs_info
[
i
]
:
# ^ explicitly provided a None for taps
# ^ explicitly provided a None for taps
_logger
.
warning
(
'Output
%
s ( index
%
d) has a initial
state
'
_logger
.
warning
(
'Output
%
s ( index
%
d) has a initial '
'
but taps is explicitly set to None '
,
'state
but taps is explicitly set to None '
,
getattr
(
outs_info
[
i
][
'initial'
],
'name'
,
'None'
),
getattr
(
outs_info
[
i
][
'initial'
],
'name'
,
'None'
),
i
)
i
)
outs_info
[
i
][
'taps'
]
=
[
-
1
]
outs_info
[
i
][
'taps'
]
=
[
-
1
]
else
:
else
:
...
@@ -434,12 +435,12 @@ def scan( fn
...
@@ -434,12 +435,12 @@ def scan( fn
# and to construct a new and complete list of inputs and
# and to construct a new and complete list of inputs and
# outputs
# outputs
n_seqs
=
0
n_seqs
=
0
scan_seqs
=
[]
# Variables passed as inputs to the scan op
scan_seqs
=
[]
# Variables passed as inputs to the scan op
inner_seqs
=
[]
# Variables passed as inputs to the inner function
inner_seqs
=
[]
# Variables passed as inputs to the inner function
inner_slices
=
[]
# Actual slices if scan is removed from the picture
inner_slices
=
[]
# Actual slices if scan is removed from the picture
# go through sequences picking up time slices as needed
# go through sequences picking up time slices as needed
for
i
,
seq
in
enumerate
(
seqs
):
for
i
,
seq
in
enumerate
(
seqs
):
# Note that you can have something like no taps for
# Note that you can have something like no taps for
# a sequence, though is highly unlikely in practice
# a sequence, though is highly unlikely in practice
if
'taps'
in
seq
:
if
'taps'
in
seq
:
...
@@ -456,31 +457,33 @@ def scan( fn
...
@@ -456,31 +457,33 @@ def scan( fn
# If not we need to use copies, that will be replaced at
# If not we need to use copies, that will be replaced at
# each frame by the corresponding slice
# each frame by the corresponding slice
actual_slice
=
seq
[
'input'
][
k
-
mintap
]
actual_slice
=
seq
[
'input'
][
k
-
mintap
]
_seq_val
=
tensor
.
as_tensor_variable
(
seq
[
'input'
])
_seq_val
=
tensor
.
as_tensor_variable
(
seq
[
'input'
])
_seq_val_slice
=
_seq_val
[
k
-
mintap
]
_seq_val_slice
=
_seq_val
[
k
-
mintap
]
nw_slice
=
_seq_val_slice
.
type
()
nw_slice
=
_seq_val_slice
.
type
()
# Try to transfer test_value to the new variable
# Try to transfer test_value to the new variable
if
config
.
compute_test_value
!=
'off'
:
if
config
.
compute_test_value
!=
'off'
:
try
:
try
:
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_seq_val_slice
)
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_seq_val_slice
)
except
AttributeError
,
e
:
except
AttributeError
,
e
:
if
config
.
compute_test_value
!=
'ignore'
:
if
config
.
compute_test_value
!=
'ignore'
:
# No need to print a warning or raise an error now,
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
# it will be done when fn will be called.
_logger
.
info
((
'Cannot compute test value for the inner '
_logger
.
info
((
'Cannot compute test value for '
'function of scan, input value missing
%
s'
),
e
)
'the inner function of scan, input value '
'missing
%
s'
),
e
)
# Add names to slices for debugging and pretty printing ..
# Add names to slices for debugging and pretty printing ..
# that is if the input already has a name
# that is if the input already has a name
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
if
k
>
0
:
if
k
>
0
:
nw_name
=
seq
[
'input'
]
.
name
+
'[t+
%
d]'
%
k
nw_name
=
seq
[
'input'
]
.
name
+
'[t+
%
d]'
%
k
elif
k
==
0
:
elif
k
==
0
:
nw_name
=
seq
[
'input'
]
.
name
+
'[t]'
nw_name
=
seq
[
'input'
]
.
name
+
'[t]'
else
:
else
:
nw_name
=
seq
[
'input'
]
.
name
+
'[t
%
d]'
%
k
nw_name
=
seq
[
'input'
]
.
name
+
'[t
%
d]'
%
k
nw_slice
.
name
=
nw_name
nw_slice
.
name
=
nw_name
# We cut the sequence such that seq[i] to correspond to
# We cut the sequence such that seq[i] to correspond to
...
@@ -490,34 +493,30 @@ def scan( fn
...
@@ -490,34 +493,30 @@ def scan( fn
else
:
else
:
offset
=
0
offset
=
0
if
maxtap
==
mintap
and
maxtap
!=
0
:
if
maxtap
==
mintap
and
maxtap
!=
0
:
nw_seq
=
seq
[
'input'
][:
abs
(
maxtap
)]
nw_seq
=
seq
[
'input'
][:
abs
(
maxtap
)]
elif
maxtap
-
k
!=
0
:
elif
maxtap
-
k
!=
0
:
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
-
(
maxtap
-
k
)]
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
-
(
maxtap
-
k
)]
else
:
else
:
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
]
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
]
if
go_backwards
:
if
go_backwards
:
nw_seq
=
nw_seq
[::
-
1
]
nw_seq
=
nw_seq
[::
-
1
]
scan_seqs
.
append
(
nw_seq
)
inner_seqs
.
append
(
nw_slice
)
scan_seqs
.
append
(
nw_seq
)
inner_slices
.
append
(
actual_slice
)
inner_seqs
.
append
(
nw_slice
)
inner_slices
.
append
(
actual_slice
)
n_seqs
+=
1
n_seqs
+=
1
# Since we've added all sequences now we need to level them up based on
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
# n_steps or their different shapes
lengths_vec
=
[]
lengths_vec
=
[]
for
seq
in
scan_seqs
:
for
seq
in
scan_seqs
:
lengths_vec
.
append
(
seq
.
shape
[
0
]
)
lengths_vec
.
append
(
seq
.
shape
[
0
]
)
if
not
scan_utils
.
isNaN_or_Inf_or_None
(
n_steps
):
if
not
scan_utils
.
isNaN_or_Inf_or_None
(
n_steps
):
# ^ N_steps should also be considered
# ^ N_steps should also be considered
lengths_vec
.
append
(
tensor
.
as_tensor
(
n_steps
)
)
lengths_vec
.
append
(
tensor
.
as_tensor
(
n_steps
))
if
len
(
lengths_vec
)
==
0
:
if
len
(
lengths_vec
)
==
0
:
# ^ No information about the number of steps
# ^ No information about the number of steps
raise
ValueError
(
' No information about the number of steps '
raise
ValueError
(
' No information about the number of steps '
'provided. Either provide a value for '
'provided. Either provide a value for '
...
@@ -534,10 +533,12 @@ def scan( fn
...
@@ -534,10 +533,12 @@ def scan( fn
actual_n_steps
=
tensor
.
as_tensor
(
n_steps
)
actual_n_steps
=
tensor
.
as_tensor
(
n_steps
)
# Add names -- it helps a lot when debugging
# Add names -- it helps a lot when debugging
for
(
nw_seq
,
seq
)
in
zip
(
scan_seqs
,
seqs
):
for
(
nw_seq
,
seq
)
in
zip
(
scan_seqs
,
seqs
):
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
[
'input'
]
.
name
+
'[
%
d:]'
%
k
nw_seq
.
name
=
seq
[
'input'
]
.
name
+
'[
%
d:]'
%
k
scan_seqs
=
[
seq
[:
actual_n_steps
]
for
seq
in
scan_seqs
]
# Conventions :
# Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided
# mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function )
# by the gradient function )
...
@@ -545,39 +546,35 @@ def scan( fn
...
@@ -545,39 +546,35 @@ def scan( fn
# sit_sot = single input tap, single output tap (t + 0)
# sit_sot = single input tap, single output tap (t + 0)
# nit_sot = no input tap, single output tap (t + 0)
# nit_sot = no input tap, single output tap (t + 0)
# MIT_MOT -- not provided by the user only by the grad function
# MIT_MOT -- not provided by the user only by the grad function
n_mit_mot
=
0
n_mit_mot
=
0
n_mit_mot_outs
=
0
n_mit_mot_outs
=
0
mit_mot_scan_inputs
=
[]
mit_mot_scan_inputs
=
[]
mit_mot_inner_inputs
=
[]
mit_mot_inner_inputs
=
[]
mit_mot_inner_outputs
=
[]
mit_mot_inner_outputs
=
[]
mit_mot_out_slices
=
[]
mit_mot_out_slices
=
[]
mit_mot_rightOrder
=
[]
mit_mot_rightOrder
=
[]
# SIT_SOT -- provided by the user
# SIT_SOT -- provided by the user
n_mit_sot
=
0
n_mit_sot
=
0
mit_sot_scan_inputs
=
[]
mit_sot_scan_inputs
=
[]
mit_sot_inner_inputs
=
[]
mit_sot_inner_inputs
=
[]
mit_sot_inner_slices
=
[]
mit_sot_inner_slices
=
[]
mit_sot_inner_outputs
=
[]
mit_sot_inner_outputs
=
[]
mit_sot_return_steps
=
{}
mit_sot_return_steps
=
{}
mit_sot_tap_array
=
[]
mit_sot_tap_array
=
[]
mit_sot_rightOrder
=
[]
mit_sot_rightOrder
=
[]
n_sit_sot
=
0
n_sit_sot
=
0
sit_sot_scan_inputs
=
[]
sit_sot_scan_inputs
=
[]
sit_sot_inner_inputs
=
[]
sit_sot_inner_inputs
=
[]
sit_sot_inner_slices
=
[]
sit_sot_inner_slices
=
[]
sit_sot_inner_outputs
=
[]
sit_sot_inner_outputs
=
[]
sit_sot_return_steps
=
{}
sit_sot_return_steps
=
{}
sit_sot_rightOrder
=
[]
sit_sot_rightOrder
=
[]
# go through outputs picking up time slices as needed
# go through outputs picking up time slices as needed
for
i
,
init_out
in
enumerate
(
outs_info
):
for
i
,
init_out
in
enumerate
(
outs_info
):
# Note that our convention dictates that if an output uses
# Note that our convention dictates that if an output uses
# just the previous time step, as a initial state we will only
# just the previous time step, as a initial state we will only
# provide a tensor of the same dimension as one time step; This
# provide a tensor of the same dimension as one time step; This
...
@@ -601,11 +598,12 @@ def scan( fn
...
@@ -601,11 +598,12 @@ def scan( fn
if
config
.
compute_test_value
!=
'ignore'
:
if
config
.
compute_test_value
!=
'ignore'
:
# No need to print a warning or raise an error now,
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
# it will be done when fn will be called.
_logger
.
info
((
'Cannot compute test value for the inner '
_logger
.
info
((
'Cannot compute test value for the '
'function of scan, input value missing
%
s'
),
e
)
'inner function of scan, input value missing
%
s'
),
e
)
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
arg
.
name
=
init_out
[
'initial'
]
.
name
+
'[t-1]'
arg
.
name
=
init_out
[
'initial'
]
.
name
+
'[t-1]'
# We need now to allocate space for storing the output and copy
# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
# the initial state over. We do this using the expand function
...
@@ -613,117 +611,119 @@ def scan( fn
...
@@ -613,117 +611,119 @@ def scan( fn
sit_sot_scan_inputs
.
append
(
sit_sot_scan_inputs
.
append
(
scan_utils
.
expand
(
scan_utils
.
expand
(
tensor
.
unbroadcast
(
tensor
.
unbroadcast
(
tensor
.
shape_padleft
(
actual_arg
),
0
)
tensor
.
shape_padleft
(
actual_arg
),
0
)
,
,
actual_n_steps
actual_n_steps
)
)
))
sit_sot_inner_slices
.
append
(
actual_arg
)
sit_sot_inner_slices
.
append
(
actual_arg
)
if
i
in
return_steps
:
if
i
in
return_steps
:
sit_sot_return_steps
[
n_sit_sot
]
=
return_steps
[
i
]
sit_sot_return_steps
[
n_sit_sot
]
=
return_steps
[
i
]
sit_sot_inner_inputs
.
append
(
arg
)
sit_sot_inner_inputs
.
append
(
arg
)
sit_sot_rightOrder
.
append
(
i
)
sit_sot_rightOrder
.
append
(
i
)
n_sit_sot
+=
1
n_sit_sot
+=
1
elif
init_out
.
get
(
'taps'
,
None
):
elif
init_out
.
get
(
'taps'
,
None
):
if
numpy
.
any
(
numpy
.
array
(
init_out
.
get
(
'taps'
,[]))
>
0
):
if
numpy
.
any
(
numpy
.
array
(
init_out
.
get
(
'taps'
,
[]))
>
0
):
# Make sure we do not have requests for future values of a
# Make sure we do not have requests for future values of a
# sequence we can not provide such values
# sequence we can not provide such values
raise
ValueError
(
'Can not use future taps of outputs'
raise
ValueError
(
'Can not use future taps of outputs'
,
,
init_out
)
init_out
)
# go through the taps
# go through the taps
mintap
=
abs
(
numpy
.
min
(
init_out
[
'taps'
]))
mintap
=
abs
(
numpy
.
min
(
init_out
[
'taps'
]))
mit_sot_tap_array
.
append
(
init_out
[
'taps'
]
)
mit_sot_tap_array
.
append
(
init_out
[
'taps'
]
)
idx_offset
=
abs
(
numpy
.
min
(
init_out
[
'taps'
]))
idx_offset
=
abs
(
numpy
.
min
(
init_out
[
'taps'
]))
# Sequence
# Sequence
mit_sot_scan_inputs
.
append
(
mit_sot_scan_inputs
.
append
(
scan_utils
.
expand
(
init_out
[
'initial'
][:
mintap
]
scan_utils
.
expand
(
init_out
[
'initial'
][:
mintap
],
,
actual_n_steps
)
)
actual_n_steps
)
)
if
i
in
return_steps
:
if
i
in
return_steps
:
mit_sot_return_steps
[
n_mit_sot
]
=
return_steps
[
i
]
mit_sot_return_steps
[
n_mit_sot
]
=
return_steps
[
i
]
mit_sot_rightOrder
.
append
(
i
)
mit_sot_rightOrder
.
append
(
i
)
n_mit_sot
+=
1
n_mit_sot
+=
1
for
k
in
init_out
[
'taps'
]:
for
k
in
init_out
[
'taps'
]:
# create a new slice
# create a new slice
actual_nw_slice
=
init_out
[
'initial'
][
k
+
mintap
]
actual_nw_slice
=
init_out
[
'initial'
][
k
+
mintap
]
_init_out_var
=
tensor
.
as_tensor_variable
(
init_out
[
'initial'
])
_init_out_var
=
tensor
.
as_tensor_variable
(
init_out
[
'initial'
])
_init_out_var_slice
=
_init_out_var
[
k
+
mintap
]
_init_out_var_slice
=
_init_out_var
[
k
+
mintap
]
nw_slice
=
_init_out_var_slice
.
type
()
nw_slice
=
_init_out_var_slice
.
type
()
# Try to transfer test_value to the new variable
# Try to transfer test_value to the new variable
if
config
.
compute_test_value
!=
'off'
:
if
config
.
compute_test_value
!=
'off'
:
try
:
try
:
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_init_out_var_slice
)
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_init_out_var_slice
)
except
AttributeError
,
e
:
except
AttributeError
,
e
:
if
config
.
compute_test_value
!=
'ignore'
:
if
config
.
compute_test_value
!=
'ignore'
:
# No need to print a warning or raise an error now,
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
# it will be done when fn will be called.
_logger
.
info
((
'Cannot compute test value for the inner '
_logger
.
info
((
'Cannot compute test value for '
'function of scan, input value missing.
%
s'
),
e
)
'the inner function of scan, input value '
'missing.
%
s'
),
e
)
# give it a name or debugging and pretty printing
# give it a name or debugging and pretty printing
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
if
k
>
0
:
if
k
>
0
:
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
'[t+
%
d]'
%
k
)
'[t+
%
d]'
%
k
)
elif
k
==
0
:
elif
k
==
0
:
nw_slice
.
name
=
init_out
[
'initial'
]
.
name
+
'[t]'
nw_slice
.
name
=
init_out
[
'initial'
]
.
name
+
'[t]'
else
:
else
:
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
'[t
%
d]'
%
k
)
'[t
%
d]'
%
k
)
mit_sot_inner_inputs
.
append
(
nw_slice
)
mit_sot_inner_inputs
.
append
(
nw_slice
)
mit_sot_inner_slices
.
append
(
actual_nw_slice
)
mit_sot_inner_slices
.
append
(
actual_nw_slice
)
#NOTE: there is another case, in which we do not want to provide
#NOTE: there is another case, in which we do not want to provide
# any previous value of the output to the inner function (i.e.
# any previous value of the output to the inner function (i.e.
# a map); in that case we do not have to do anything ..
# a map); in that case we do not have to do anything ..
# Re-order args
# Re-order args
max_mit_sot
=
numpy
.
max
(
[
-
1
]
+
mit_sot_rightOrder
)
+
1
max_mit_sot
=
numpy
.
max
(
[
-
1
]
+
mit_sot_rightOrder
)
+
1
max_sit_sot
=
numpy
.
max
(
[
-
1
]
+
sit_sot_rightOrder
)
+
1
max_sit_sot
=
numpy
.
max
(
[
-
1
]
+
sit_sot_rightOrder
)
+
1
n_elems
=
numpy
.
max
(
[
max_mit_sot
,
max_sit_sot
]
)
n_elems
=
numpy
.
max
([
max_mit_sot
,
max_sit_sot
]
)
_ordered_args
=
[[]
for
x
in
xrange
(
n_elems
)]
_ordered_args
=
[[]
for
x
in
xrange
(
n_elems
)]
offset
=
0
offset
=
0
for
idx
in
xrange
(
n_mit_sot
):
for
idx
in
xrange
(
n_mit_sot
):
n_inputs
=
len
(
mit_sot_tap_array
[
idx
])
n_inputs
=
len
(
mit_sot_tap_array
[
idx
])
if
n_fixed_steps
in
[
1
,
-
1
]:
if
n_fixed_steps
in
[
1
,
-
1
]:
_ordered_args
[
mit_sot_rightOrder
[
idx
]]
=
\
_ordered_args
[
mit_sot_rightOrder
[
idx
]]
=
\
mit_sot_inner_slices
[
offset
:
offset
+
n_inputs
]
mit_sot_inner_slices
[
offset
:
offset
+
n_inputs
]
else
:
else
:
_ordered_args
[
mit_sot_rightOrder
[
idx
]]
=
\
_ordered_args
[
mit_sot_rightOrder
[
idx
]]
=
\
mit_sot_inner_inputs
[
offset
:
offset
+
n_inputs
]
mit_sot_inner_inputs
[
offset
:
offset
+
n_inputs
]
offset
+=
n_inputs
offset
+=
n_inputs
for
idx
in
xrange
(
n_sit_sot
):
for
idx
in
xrange
(
n_sit_sot
):
if
n_fixed_steps
in
[
1
,
-
1
]:
if
n_fixed_steps
in
[
1
,
-
1
]:
_ordered_args
[
sit_sot_rightOrder
[
idx
]]
=
\
_ordered_args
[
sit_sot_rightOrder
[
idx
]]
=
\
[
sit_sot_inner_slices
[
idx
]
]
[
sit_sot_inner_slices
[
idx
]
]
else
:
else
:
_ordered_args
[
sit_sot_rightOrder
[
idx
]]
=
\
_ordered_args
[
sit_sot_rightOrder
[
idx
]]
=
\
[
sit_sot_inner_inputs
[
idx
]
]
[
sit_sot_inner_inputs
[
idx
]
]
ordered_args
=
[]
ordered_args
=
[]
for
ls
in
_ordered_args
:
for
ls
in
_ordered_args
:
ordered_args
+=
ls
ordered_args
+=
ls
if
n_fixed_steps
in
[
1
,
-
1
]:
if
n_fixed_steps
in
[
1
,
-
1
]:
args
=
(
inner_slices
+
args
=
(
inner_slices
+
ordered_args
+
ordered_args
+
non_seqs
)
non_seqs
)
else
:
else
:
args
=
(
inner_seqs
+
args
=
(
inner_seqs
+
ordered_args
+
ordered_args
+
non_seqs
)
non_seqs
)
# add only the non-shared variables and non-constants to the arguments of the dummy
# add only the non-shared variables and non-constants to the arguments of
# function [ a function should not get shared variables or constants as input ]
# the dummy function [ a function should not get shared variables or
# constants as input ]
dummy_args
=
[
arg
for
arg
in
args
dummy_args
=
[
arg
for
arg
in
args
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
not
isinstance
(
arg
,
tensor
.
Constant
)
)]
not
isinstance
(
arg
,
tensor
.
Constant
))]
# when we apply the lambda expression we get a mixture of update rules
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
# and outputs that needs to be separated
condition
,
outputs
,
updates
=
scan_utils
.
get_updates_and_outputs
(
fn
(
*
args
))
condition
,
outputs
,
updates
=
scan_utils
.
get_updates_and_outputs
(
fn
(
*
args
))
if
condition
is
not
None
:
if
condition
is
not
None
:
as_while
=
True
as_while
=
True
...
@@ -733,14 +733,13 @@ def scan( fn
...
@@ -733,14 +733,13 @@ def scan( fn
### Step 3. Check if we actually need scan and remove it if we don't
### Step 3. Check if we actually need scan and remove it if we don't
##
##
if
n_fixed_steps
in
[
1
,
-
1
]:
if
n_fixed_steps
in
[
1
,
-
1
]:
# We do not need to use the scan op anymore, so we can just return
# We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have
# the outputs and updates we have
if
condition
is
not
None
:
if
condition
is
not
None
:
_logger
.
warning
(
(
'When the number of steps is fixed and equal to 1,
'
_logger
.
warning
(
(
'When the number of steps is fixed and equal
'
' the provided stopping condition, '
,
str
(
condition
)
,
'to 1, the provided stopping condition, '
,
' is ignored'
))
str
(
condition
),
' is ignored'
))
for
pos
,
inner_out
in
enumerate
(
outputs
):
for
pos
,
inner_out
in
enumerate
(
outputs
):
# we need to see if we need to pad our sequences with an
# we need to see if we need to pad our sequences with an
...
@@ -749,16 +748,15 @@ def scan( fn
...
@@ -749,16 +748,15 @@ def scan( fn
# then, if we return the output as given by the innner function
# then, if we return the output as given by the innner function
# this will represent only a slice and it will have one
# this will represent only a slice and it will have one
# dimension less.
# dimension less.
if
(
isinstance
(
inner_out
.
type
,
tensor
.
TensorType
)
and
if
(
isinstance
(
inner_out
.
type
,
tensor
.
TensorType
)
and
return_steps
.
get
(
pos
,
0
)
!=
1
):
return_steps
.
get
(
pos
,
0
)
!=
1
):
outputs
[
pos
]
=
tensor
.
unbroadcast
(
outputs
[
pos
]
=
tensor
.
unbroadcast
(
tensor
.
shape_padleft
(
inner_out
),
0
)
tensor
.
shape_padleft
(
inner_out
),
0
)
if
len
(
outputs
)
==
1
:
if
len
(
outputs
)
==
1
:
outputs
=
outputs
[
0
]
outputs
=
outputs
[
0
]
return
(
outputs
,
updates
)
return
(
outputs
,
updates
)
##
##
### Step 4. Compile the dummy function
### Step 4. Compile the dummy function
##
##
...
@@ -778,11 +776,11 @@ def scan( fn
...
@@ -778,11 +776,11 @@ def scan( fn
replace
=
dict
(
zip
(
non_seqs
,
replace
=
dict
(
zip
(
non_seqs
,
fake_nonseqs
)))
fake_nonseqs
)))
all_inputs
=
itertools
.
ifilter
(
all_inputs
=
itertools
.
ifilter
(
lambda
x
:
(
isinstance
(
x
,
gof
.
Variable
)
and
lambda
x
:
(
isinstance
(
x
,
gof
.
Variable
)
and
not
isinstance
(
x
,
SharedVariable
)
and
not
isinstance
(
x
,
SharedVariable
)
and
not
isinstance
(
x
,
gof
.
Constant
)
),
not
isinstance
(
x
,
gof
.
Constant
)),
gof
.
graph
.
inputs
(
fake_outputs
)
)
gof
.
graph
.
inputs
(
fake_outputs
)
)
extra_inputs
=
filter
(
lambda
x
:
x
not
in
args
+
fake_nonseqs
,
extra_inputs
=
filter
(
lambda
x
:
x
not
in
args
+
fake_nonseqs
,
all_inputs
)
all_inputs
)
non_seqs
+=
extra_inputs
non_seqs
+=
extra_inputs
## Note we do not use all_inputs directly since the order of variables
## Note we do not use all_inputs directly since the order of variables
...
@@ -792,12 +790,11 @@ def scan( fn
...
@@ -792,12 +790,11 @@ def scan( fn
dummy_outs
=
outputs
dummy_outs
=
outputs
if
condition
is
not
None
:
if
condition
is
not
None
:
dummy_outs
.
append
(
condition
)
dummy_outs
.
append
(
condition
)
dummy_f
=
function
(
dummy_args
dummy_f
=
function
(
dummy_args
,
,
dummy_outs
dummy_outs
,
,
updates
=
updates
updates
=
updates
,
,
mode
=
compile
.
mode
.
Mode
(
linker
=
'py'
,
mode
=
compile
.
mode
.
Mode
(
linker
=
'py'
,
optimizer
=
None
)
)
optimizer
=
None
))
##
##
### Step 5. Re-arange inputs of scan into a more strict order
### Step 5. Re-arange inputs of scan into a more strict order
...
@@ -806,7 +803,6 @@ def scan( fn
...
@@ -806,7 +803,6 @@ def scan( fn
## Step 5.0 Check the outputs of the dummy function to see if they
## Step 5.0 Check the outputs of the dummy function to see if they
## match with user provided data
## match with user provided data
# if the number of outputs to the function does not match the number of
# if the number of outputs to the function does not match the number of
# assumed outputs until now (provided by the user) there can be
# assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the
# only one explanation: No information is provided for any of the
...
@@ -814,7 +810,7 @@ def scan( fn
...
@@ -814,7 +810,7 @@ def scan( fn
tmp_dummy_f_outs
=
len
(
dummy_f
.
maker
.
outputs
)
tmp_dummy_f_outs
=
len
(
dummy_f
.
maker
.
outputs
)
if
as_while
:
if
as_while
:
tmp_dummy_f_outs
-=
1
tmp_dummy_f_outs
-=
1
if
not
(
tmp_dummy_f_outs
==
n_outs
or
outs_info
==
[]):
if
not
(
tmp_dummy_f_outs
==
n_outs
or
outs_info
==
[]):
raise
ValueError
(
'Please provide None as output_info for '
raise
ValueError
(
'Please provide None as output_info for '
'any output that does not feed back into '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) '
)
'scan (i.e. it behaves like a map) '
)
...
@@ -823,95 +819,91 @@ def scan( fn
...
@@ -823,95 +819,91 @@ def scan( fn
n_outs
=
len
(
dummy_f
.
maker
.
outputs
)
n_outs
=
len
(
dummy_f
.
maker
.
outputs
)
if
as_while
:
if
as_while
:
n_outs
=
n_outs
-
1
n_outs
=
n_outs
-
1
outs_info
=
[
dict
()
for
x
in
xrange
(
n_outs
)
]
outs_info
=
[
dict
()
for
x
in
xrange
(
n_outs
)]
## Step 5.1 Outputs with taps different than -1
## Step 5.1 Outputs with taps different than -1
for
i
,
out
in
enumerate
(
outs_info
):
for
i
,
out
in
enumerate
(
outs_info
):
if
'taps'
in
out
and
out
[
'taps'
]
!=
[
-
1
]:
if
'taps'
in
out
and
out
[
'taps'
]
!=
[
-
1
]:
mit_sot_inner_outputs
.
append
(
outputs
[
i
])
mit_sot_inner_outputs
.
append
(
outputs
[
i
])
## Step 5.2 Outputs with tap equal to -1
## Step 5.2 Outputs with tap equal to -1
for
i
,
out
in
enumerate
(
outs_info
):
for
i
,
out
in
enumerate
(
outs_info
):
if
'taps'
in
out
and
out
[
'taps'
]
==
[
-
1
]:
if
'taps'
in
out
and
out
[
'taps'
]
==
[
-
1
]:
sit_sot_inner_outputs
.
append
(
outputs
[
i
]
)
sit_sot_inner_outputs
.
append
(
outputs
[
i
])
## Step 5.3 Outputs that correspond to update rules of shared variables
## Step 5.3 Outputs that correspond to update rules of shared variables
givens
=
{}
givens
=
{}
n_shared_outs
=
0
n_shared_outs
=
0
shared_scan_inputs
=
[]
shared_scan_inputs
=
[]
shared_inner_inputs
=
[]
shared_inner_inputs
=
[]
shared_inner_outputs
=
[]
shared_inner_outputs
=
[]
for
input
in
dummy_f
.
maker
.
expanded_inputs
:
for
input
in
dummy_f
.
maker
.
expanded_inputs
:
if
isinstance
(
input
.
variable
,
SharedVariable
)
and
input
.
update
:
if
isinstance
(
input
.
variable
,
SharedVariable
)
and
input
.
update
:
new_var
=
safe_new
(
input
.
variable
)
new_var
=
safe_new
(
input
.
variable
)
if
getattr
(
input
.
variable
,
'name'
,
None
)
is
not
None
:
if
getattr
(
input
.
variable
,
'name'
,
None
)
is
not
None
:
new_var
.
name
=
input
.
variable
.
name
+
'_copy'
new_var
.
name
=
input
.
variable
.
name
+
'_copy'
shared_inner_inputs
.
append
(
new_var
)
shared_inner_inputs
.
append
(
new_var
)
shared_scan_inputs
.
append
(
input
.
variable
)
shared_scan_inputs
.
append
(
input
.
variable
)
shared_inner_outputs
.
append
(
input
.
update
)
shared_inner_outputs
.
append
(
input
.
update
)
givens
[
input
.
variable
]
=
new_var
givens
[
input
.
variable
]
=
new_var
n_shared_outs
+=
1
n_shared_outs
+=
1
## Step 5.4 Outputs with no taps used in the input
## Step 5.4 Outputs with no taps used in the input
n_nit_sot
=
0
n_nit_sot
=
0
nit_sot_inner_outputs
=
[]
nit_sot_inner_outputs
=
[]
nit_sot_return_steps
=
{}
nit_sot_return_steps
=
{}
nit_sot_rightOrder
=
[]
nit_sot_rightOrder
=
[]
for
i
,
out
in
enumerate
(
outs_info
):
for
i
,
out
in
enumerate
(
outs_info
):
if
not
'taps'
in
out
:
if
not
'taps'
in
out
:
nit_sot_inner_outputs
.
append
(
outputs
[
i
]
)
nit_sot_inner_outputs
.
append
(
outputs
[
i
]
)
if
i
in
return_steps
:
if
i
in
return_steps
:
nit_sot_return_steps
[
n_nit_sot
]
=
return_steps
[
i
]
nit_sot_return_steps
[
n_nit_sot
]
=
return_steps
[
i
]
nit_sot_rightOrder
.
append
(
i
)
nit_sot_rightOrder
.
append
(
i
)
n_nit_sot
+=
1
n_nit_sot
+=
1
## Step 5.5 all other arguments including extra inputs
## Step 5.5 all other arguments including extra inputs
other_scan_args
=
[]
other_scan_args
=
[]
other_inner_args
=
[]
other_inner_args
=
[]
other_scan_args
+=
[
arg
for
arg
in
non_seqs
other_scan_args
+=
[
arg
for
arg
in
non_seqs
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
not
isinstance
(
arg
,
tensor
.
Constant
))]
not
isinstance
(
arg
,
tensor
.
Constant
))]
## Step 5.6 all shared variables with no update rules
## Step 5.6 all shared variables with no update rules
other_inner_args
+=
[
safe_new
(
arg
,
'_copy'
)
for
arg
in
non_seqs
other_inner_args
+=
[
safe_new
(
arg
,
'_copy'
)
for
arg
in
non_seqs
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
not
isinstance
(
arg
,
tensor
.
Constant
))]
not
isinstance
(
arg
,
tensor
.
Constant
))]
givens
.
update
(
dict
(
zip
(
other_scan_args
,
other_inner_args
)
))
givens
.
update
(
dict
(
zip
(
other_scan_args
,
other_inner_args
)
))
other_shared_scan_args
=
[
arg
.
variable
for
arg
other_shared_scan_args
=
[
arg
.
variable
for
arg
in
dummy_f
.
maker
.
expanded_inputs
in
dummy_f
.
maker
.
expanded_inputs
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
not
arg
.
update
)
]
not
arg
.
update
)]
other_shared_inner_args
=
[
safe_new
(
arg
.
variable
,
'_copy'
)
for
arg
other_shared_inner_args
=
[
safe_new
(
arg
.
variable
,
'_copy'
)
for
arg
in
dummy_f
.
maker
.
expanded_inputs
in
dummy_f
.
maker
.
expanded_inputs
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
not
arg
.
update
)
]
not
arg
.
update
)]
givens
.
update
(
dict
(
zip
(
other_shared_scan_args
,
givens
.
update
(
dict
(
zip
(
other_shared_scan_args
,
other_shared_inner_args
)
)
)
other_shared_inner_args
)))
##
##
### Step 6. Re-order the outputs and clone them replacing things
### Step 6. Re-order the outputs and clone them replacing things
### using the givens
### using the givens
##
##
inner_inputs
=
(
inner_seqs
+
inner_inputs
=
(
inner_seqs
+
mit_mot_inner_inputs
+
mit_mot_inner_inputs
+
mit_sot_inner_inputs
+
mit_sot_inner_inputs
+
sit_sot_inner_inputs
+
sit_sot_inner_inputs
+
shared_inner_inputs
+
shared_inner_inputs
+
other_shared_inner_args
+
other_shared_inner_args
+
other_inner_args
)
other_inner_args
)
inner_outs
=
(
mit_mot_inner_outputs
+
inner_outs
=
(
mit_mot_inner_outputs
+
mit_sot_inner_outputs
+
mit_sot_inner_outputs
+
sit_sot_inner_outputs
+
sit_sot_inner_outputs
+
nit_sot_inner_outputs
+
nit_sot_inner_outputs
+
shared_inner_outputs
)
shared_inner_outputs
)
if
condition
is
not
None
:
if
condition
is
not
None
:
inner_outs
.
append
(
condition
)
inner_outs
.
append
(
condition
)
# Cuda is imported here, instead of being imported on top of the file
# Cuda is imported here, instead of being imported on top of the file
...
@@ -926,59 +918,58 @@ def scan( fn
...
@@ -926,59 +918,58 @@ def scan( fn
# variables are put on GPU right away >:| ,
# variables are put on GPU right away >:| ,
new_givens
=
{}
new_givens
=
{}
for
w
,
w_copy
in
givens
.
iteritems
():
for
w
,
w_copy
in
givens
.
iteritems
():
if
(
isinstance
(
w
.
type
,
cuda
.
CudaNdarrayType
)
if
(
isinstance
(
w
.
type
,
cuda
.
CudaNdarrayType
)
and
isinstance
(
w_copy
.
type
,
tensor
.
TensorType
)):
and
isinstance
(
w_copy
.
type
,
tensor
.
TensorType
)):
for
o
in
inner_outs
:
for
o
in
inner_outs
:
new_givens
=
traverse
(
o
,
w
,
w_copy
,
new_givens
)
new_givens
=
traverse
(
o
,
w
,
w_copy
,
new_givens
)
else
:
else
:
new_givens
[
w
]
=
w_copy
new_givens
[
w
]
=
w_copy
else
:
else
:
new_givens
=
givens
new_givens
=
givens
new_outs
=
scan_utils
.
clone
(
inner_outs
,
replace
=
new_givens
)
new_outs
=
scan_utils
.
clone
(
inner_outs
,
replace
=
new_givens
)
##
##
### Step 7. Create the Scan Op
### Step 7. Create the Scan Op
##
##
tap_array
=
mit_sot_tap_array
+
[[
-
1
]
for
x
in
xrange
(
n_sit_sot
)]
tap_array
=
mit_sot_tap_array
+
[[
-
1
]
for
x
in
xrange
(
n_sit_sot
)]
info
=
{}
info
=
{}
info
[
'tap_array'
]
=
tap_array
info
[
'tap_array'
]
=
tap_array
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'n_mit_sot'
]
=
n_mit_sot
info
[
'n_mit_sot'
]
=
n_mit_sot
info
[
'n_sit_sot'
]
=
n_sit_sot
info
[
'n_sit_sot'
]
=
n_sit_sot
info
[
'n_shared_outs'
]
=
n_shared_outs
info
[
'n_shared_outs'
]
=
n_shared_outs
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'truncate_gradient'
]
=
truncate_gradient
info
[
'truncate_gradient'
]
=
truncate_gradient
info
[
'name'
]
=
name
info
[
'name'
]
=
name
info
[
'mode'
]
=
mode
info
[
'mode'
]
=
mode
info
[
'inplace'
]
=
False
info
[
'inplace'
]
=
False
info
[
'gpu'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
as_while
info
[
'as_while'
]
=
as_while
info
[
'profile'
]
=
profile
info
[
'profile'
]
=
profile
local_op
=
scan_op
.
Scan
(
inner_inputs
,
new_outs
,
info
)
local_op
=
scan_op
.
Scan
(
inner_inputs
,
new_outs
,
info
)
##
##
### Step 8. Compute the outputs using the scan op
### Step 8. Compute the outputs using the scan op
##
##
_scan_inputs
=
(
scan_seqs
+
_scan_inputs
=
(
scan_seqs
+
mit_mot_scan_inputs
+
mit_mot_scan_inputs
+
mit_sot_scan_inputs
+
mit_sot_scan_inputs
+
sit_sot_scan_inputs
+
sit_sot_scan_inputs
+
shared_scan_inputs
+
shared_scan_inputs
+
[
actual_n_steps
for
x
in
xrange
(
n_nit_sot
)
]
+
[
actual_n_steps
for
x
in
xrange
(
n_nit_sot
)
]
+
other_shared_scan_args
+
other_shared_scan_args
+
other_scan_args
)
other_scan_args
)
scan_inputs
=
[]
scan_inputs
=
[]
for
arg
in
[
actual_n_steps
]
+
_scan_inputs
:
for
arg
in
[
actual_n_steps
]
+
_scan_inputs
:
try
:
try
:
arg
=
tensor
.
as_tensor_variable
(
arg
)
arg
=
tensor
.
as_tensor_variable
(
arg
)
except
TypeError
:
except
TypeError
:
...
@@ -986,8 +977,8 @@ def scan( fn
...
@@ -986,8 +977,8 @@ def scan( fn
# to make sure no input is a cuda ndarrays
# to make sure no input is a cuda ndarrays
pass
pass
scan_inputs
+=
[
arg
]
scan_inputs
+=
[
arg
]
scan_outs
=
local_op
(
*
scan_inputs
)
scan_outs
=
local_op
(
*
scan_inputs
)
if
type
(
scan_outs
)
not
in
(
list
,
tuple
):
if
type
(
scan_outs
)
not
in
(
list
,
tuple
):
scan_outs
=
[
scan_outs
]
scan_outs
=
[
scan_outs
]
##
##
### Step 9. Figure out which outs are update rules for shared variables
### Step 9. Figure out which outs are update rules for shared variables
...
@@ -995,55 +986,57 @@ def scan( fn
...
@@ -995,55 +986,57 @@ def scan( fn
##
##
update_map
=
Updates
()
update_map
=
Updates
()
def remove_dimensions(outs, steps_return, offsets=None):
    """Trim every entry of ``outs`` to the part that should be returned.

    :param outs: sequence of sliceable outputs (indexed by position).
    :param steps_return: mapping from a position in ``outs`` to the number
        of trailing steps to keep for that output.
    :param offsets: optional sequence giving, per position, how many
        leading entries to drop from outputs that have no entry in
        ``steps_return``.  When ``None`` such outputs are returned as is.
    :return: a new list with one (possibly trimmed) entry per output.
    """
    out_ls = []
    for idx, out in enumerate(outs):
        if idx in steps_return:
            n_keep = steps_return[idx]
            if n_keep > 1:
                # Keep the last ``n_keep`` entries as a slice.
                out_ls.append(out[-n_keep:])
            else:
                # A single step is returned as an element, not as a
                # length-one slice.
                out_ls.append(out[-1])
        elif offsets is None:
            out_ls.append(out)
        else:
            # Drop the leading entries counted by the offset.
            out_ls.append(out[offsets[idx]:])
    return out_ls
offset
=
n_mit_mot
offset
=
n_mit_mot
offsets
=
[
abs
(
numpy
.
min
(
x
))
for
x
in
mit_sot_tap_array
]
offsets
=
[
abs
(
numpy
.
min
(
x
))
for
x
in
mit_sot_tap_array
]
mit_sot_outs
=
remove_dimensions
(
mit_sot_outs
=
remove_dimensions
(
scan_outs
[
offset
:
offset
+
n_mit_sot
]
scan_outs
[
offset
:
offset
+
n_mit_sot
],
,
mit_sot_return_steps
mit_sot_return_steps
,
,
offsets
)
offsets
)
offset
+=
n_mit_sot
offset
+=
n_mit_sot
offsets
=
[
1
for
x
in
xrange
(
n_sit_sot
)
]
offsets
=
[
1
for
x
in
xrange
(
n_sit_sot
)
]
sit_sot_outs
=
remove_dimensions
(
sit_sot_outs
=
remove_dimensions
(
scan_outs
[
offset
:
offset
+
n_sit_sot
]
scan_outs
[
offset
:
offset
+
n_sit_sot
],
,
sit_sot_return_steps
sit_sot_return_steps
,
,
offsets
)
offsets
)
offset
+=
n_sit_sot
offset
+=
n_sit_sot
nit_sot_outs
=
remove_dimensions
(
nit_sot_outs
=
remove_dimensions
(
scan_outs
[
offset
:
offset
+
n_nit_sot
]
scan_outs
[
offset
:
offset
+
n_nit_sot
],
,
nit_sot_return_steps
)
nit_sot_return_steps
)
offset
+=
n_nit_sot
offset
+=
n_nit_sot
for
idx
,
update_rule
in
enumerate
(
scan_outs
[
offset
:
offset
+
n_shared_outs
]):
for
idx
,
update_rule
in
enumerate
(
scan_outs
[
offset
:
offset
+
n_shared_outs
]):
update_map
[
shared_scan_inputs
[
idx
]]
=
update_rule
update_map
[
shared_scan_inputs
[
idx
]]
=
update_rule
_scan_out_list
=
(
mit_sot_outs
+
_scan_out_list
=
(
mit_sot_outs
+
sit_sot_outs
+
sit_sot_outs
+
nit_sot_outs
)
nit_sot_outs
)
# Step 10. I need to reorder the outputs to be in the order expected by
# Step 10. I need to reorder the outputs to be in the order expected by
# the user
# the user
rightOrder
=
(
mit_sot_rightOrder
+
rightOrder
=
(
mit_sot_rightOrder
+
sit_sot_rightOrder
+
sit_sot_rightOrder
+
nit_sot_rightOrder
)
nit_sot_rightOrder
)
scan_out_list
=
[
None
]
*
len
(
rightOrder
)
scan_out_list
=
[
None
]
*
len
(
rightOrder
)
for
idx
,
pos
in
enumerate
(
rightOrder
):
for
idx
,
pos
in
enumerate
(
rightOrder
):
scan_out_list
[
pos
]
=
_scan_out_list
[
idx
]
scan_out_list
[
pos
]
=
_scan_out_list
[
idx
]
if
len
(
scan_out_list
)
==
1
:
if
len
(
scan_out_list
)
==
1
:
scan_out_list
=
scan_out_list
[
0
]
scan_out_list
=
scan_out_list
[
0
]
elif
len
(
scan_out_list
)
==
0
:
elif
len
(
scan_out_list
)
==
0
:
...
...
theano/scan_module/tests/test_scan.py
浏览文件 @
f4b29ab6
...
@@ -2481,6 +2481,26 @@ class T_Scan(unittest.TestCase):
...
@@ -2481,6 +2481,26 @@ class T_Scan(unittest.TestCase):
if
isinstance
(
x
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)]
if
isinstance
(
x
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)]
assert
len
(
lssc
)
==
1
assert
len
(
lssc
)
==
1
def test_grad_multiple_seqs_different_nsteps(self):
    """Gradient through a scan whose sequences have different lengths.

    Example provided by Michael Forbes.  Per the original note, this
    checks that the sequences are clipped to n_steps before the gradient
    is computed, so that reversing them picks up the right values.
    """
    c = theano.tensor.vector('c')
    x = theano.tensor.scalar('x')
    _max_coefficients_supported = 1000
    # Deliberately much longer than the coefficient vector fed below.
    full_range = theano.tensor.arange(_max_coefficients_supported)
    components, updates = theano.scan(
        fn=lambda coeff, power, free_var: coeff * (free_var ** power),
        outputs_info=None,
        sequences=[c, full_range],
        non_sequences=x)
    P = components.sum()
    dP = theano.tensor.grad(P, x)
    tf = theano.function([c, x], dP)
    # P(x) = 1 + 2*x - 3*x**2 + 4*x**3, so dP/dx at x = 2 is
    # 2 - 12 + 48 = 38.  Compare with a tolerance rather than exact
    # equality: the value is the result of floating point arithmetic.
    assert numpy.allclose(tf([1.0, 2.0, -3.0, 4.0], 2.0), 38)
def
test_return_steps
(
self
):
def
test_return_steps
(
self
):
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
vW_in2
=
asarrayX
(
rng
.
uniform
(
size
=
(
2
,),
low
=
-
5.
,
high
=
5.
))
vW_in2
=
asarrayX
(
rng
.
uniform
(
size
=
(
2
,),
low
=
-
5.
,
high
=
5.
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论