Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
fc22aaf9
提交
fc22aaf9
authored
10月 28, 2011
作者:
Razvan Pascanu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
making scan.py PEP8 compatible
上级
80f150b2
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
254 行增加
和
262 行删除
+254
-262
scan.py
theano/scan_module/scan.py
+254
-262
没有找到文件。
theano/scan_module/scan.py
浏览文件 @
fc22aaf9
...
...
@@ -34,10 +34,10 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
``foldr()``.
"""
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
)
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
...
@@ -63,16 +63,16 @@ from scan_utils import safe_new, traverse
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan'
)
def
scan
(
fn
,
sequences
=
None
,
outputs_info
=
None
,
non_sequences
=
None
,
n_steps
=
None
,
truncate_gradient
=
-
1
,
go_backwards
=
False
,
mode
=
None
,
name
=
None
,
profile
=
False
):
def
scan
(
fn
,
sequences
=
None
,
outputs_info
=
None
,
non_sequences
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
mode
=
None
,
name
=
None
,
profile
=
False
):
"""
This function constructs and applies a Scan op to the provided
arguments.
...
...
@@ -333,12 +333,12 @@ def scan( fn
'''
if
x
is
None
:
return
[]
elif
not
isinstance
(
x
,
(
list
,
tuple
)):
elif
not
isinstance
(
x
,
(
list
,
tuple
)):
return
[
x
]
else
:
return
list
(
x
)
seqs
=
wrap_into_list
(
sequences
)
seqs
=
wrap_into_list
(
sequences
)
outs_info
=
wrap_into_list
(
outputs_info
)
# Make sure we get rid of numpy arrays or ints or anything like that
...
...
@@ -356,19 +356,19 @@ def scan( fn
# To do that we check here to see the nature of n_steps
n_fixed_steps
=
None
if
isinstance
(
n_steps
,
(
float
,
int
)):
if
isinstance
(
n_steps
,
(
float
,
int
)):
n_fixed_steps
=
int
(
n_steps
)
else
:
try
:
try
:
n_fixed_steps
=
opt
.
get_constant_value
(
n_steps
)
except
(
TypeError
,
AttributeError
):
n_fixed_steps
=
None
# Check n_steps is an int
if
(
hasattr
(
n_steps
,
'dtype'
)
and
str
(
n_steps
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
):
if
(
hasattr
(
n_steps
,
'dtype'
)
and
str
(
n_steps
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
):
raise
ValueError
(
' n_steps must be an int. dtype provided '
'is
%
s'
%
n_steps
.
dtype
)
'is
%
s'
%
n_steps
.
dtype
)
# compute number of sequences and number of outputs
n_seqs
=
len
(
seqs
)
...
...
@@ -377,11 +377,11 @@ def scan( fn
return_steps
=
{}
# wrap sequences in a dictionary if they are not already dictionaries
for
i
in
xrange
(
n_seqs
):
if
not
isinstance
(
seqs
[
i
],
dict
)
:
if
not
isinstance
(
seqs
[
i
],
dict
):
seqs
[
i
]
=
dict
(
input
=
seqs
[
i
],
taps
=
[
0
])
elif
seqs
[
i
]
.
get
(
'taps'
,
None
):
elif
seqs
[
i
]
.
get
(
'taps'
,
None
):
seqs
[
i
][
'taps'
]
=
wrap_into_list
(
seqs
[
i
][
'taps'
])
elif
seqs
[
i
]
.
get
(
'taps'
,
True
)
is
None
:
elif
seqs
[
i
]
.
get
(
'taps'
,
True
)
is
None
:
# seqs dictionary does not have the ``taps`` key
seqs
[
i
][
'taps'
]
=
[
0
]
...
...
@@ -391,30 +391,31 @@ def scan( fn
if
isinstance
(
outs_info
[
i
],
dict
):
# DEPRICATED :
if
outs_info
[
i
]
.
get
(
'return_steps'
,
None
):
_logger
.
warning
(
(
"Using `return_steps` has been depricated."
" Simply select the entries you need using "
" a subtensor. Scan will optimize memory "
" consumption, so do not worry about that."
))
_logger
.
warning
((
"Using `return_steps` has been "
"depricated. Simply select the entries you "
"need using a subtensor. Scan will optimize "
"memory consumption, so do not worry about "
"that."
))
return_steps
[
i
]
=
outs_info
[
i
][
'return_steps'
]
# END
if
not
isinstance
(
outs_info
[
i
],
dict
):
# by default any output has a tap value of -1
outs_info
[
i
]
=
dict
(
initial
=
outs_info
[
i
],
taps
=
[
-
1
])
elif
(
not
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
outs_info
[
i
]
=
dict
(
initial
=
outs_info
[
i
],
taps
=
[
-
1
])
elif
(
not
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
# ^ no initial state but taps provided
raise
ValueError
(
(
'If you are using slices of an output '
'you need to provide a initial state '
'for it'
),
outs_info
[
i
]
)
elif
(
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
not
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
raise
ValueError
(
(
'If you are using slices of an output '
'you need to provide a initial state '
'for it'
),
outs_info
[
i
]
)
elif
(
outs_info
[
i
]
.
get
(
'initial'
,
None
)
and
not
outs_info
[
i
]
.
get
(
'taps'
,
None
)):
# ^ initial state but taps not provided
if
outs_info
[
i
]
.
has_key
(
'taps'
)
:
if
'taps'
in
outs_info
[
i
]
:
# ^ explicitly provided a None for taps
_logger
.
warning
(
'Output
%
s ( index
%
d) has a initial
state
'
'
but taps is explicitly set to None '
,
getattr
(
outs_info
[
i
][
'initial'
],
'name'
,
'None'
),
_logger
.
warning
(
'Output
%
s ( index
%
d) has a initial '
'state
but taps is explicitly set to None '
,
getattr
(
outs_info
[
i
][
'initial'
],
'name'
,
'None'
),
i
)
outs_info
[
i
][
'taps'
]
=
[
-
1
]
else
:
...
...
@@ -434,12 +435,12 @@ def scan( fn
# and to construct a new and complete list of inputs and
# outputs
n_seqs
=
0
scan_seqs
=
[]
# Variables passed as inputs to the scan op
inner_seqs
=
[]
# Variables passed as inputs to the inner function
inner_slices
=
[]
# Actual slices if scan is removed from the picture
n_seqs
=
0
scan_seqs
=
[]
# Variables passed as inputs to the scan op
inner_seqs
=
[]
# Variables passed as inputs to the inner function
inner_slices
=
[]
# Actual slices if scan is removed from the picture
# go through sequences picking up time slices as needed
for
i
,
seq
in
enumerate
(
seqs
):
for
i
,
seq
in
enumerate
(
seqs
):
# Note that you can have something like no taps for
# a sequence, though is highly unlikely in practice
if
'taps'
in
seq
:
...
...
@@ -456,31 +457,33 @@ def scan( fn
# If not we need to use copies, that will be replaced at
# each frame by the corresponding slice
actual_slice
=
seq
[
'input'
][
k
-
mintap
]
actual_slice
=
seq
[
'input'
][
k
-
mintap
]
_seq_val
=
tensor
.
as_tensor_variable
(
seq
[
'input'
])
_seq_val_slice
=
_seq_val
[
k
-
mintap
]
_seq_val_slice
=
_seq_val
[
k
-
mintap
]
nw_slice
=
_seq_val_slice
.
type
()
# Try to transfer test_value to the new variable
if
config
.
compute_test_value
!=
'off'
:
try
:
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_seq_val_slice
)
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_seq_val_slice
)
except
AttributeError
,
e
:
if
config
.
compute_test_value
!=
'ignore'
:
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger
.
info
((
'Cannot compute test value for the inner '
'function of scan, input value missing
%
s'
),
e
)
_logger
.
info
((
'Cannot compute test value for '
'the inner function of scan, input value '
'missing
%
s'
),
e
)
# Add names to slices for debugging and pretty printing ..
# that is if the input already has a name
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
if
k
>
0
:
nw_name
=
seq
[
'input'
]
.
name
+
'[t+
%
d]'
%
k
nw_name
=
seq
[
'input'
]
.
name
+
'[t+
%
d]'
%
k
elif
k
==
0
:
nw_name
=
seq
[
'input'
]
.
name
+
'[t]'
else
:
nw_name
=
seq
[
'input'
]
.
name
+
'[t
%
d]'
%
k
nw_name
=
seq
[
'input'
]
.
name
+
'[t
%
d]'
%
k
nw_slice
.
name
=
nw_name
# We cut the sequence such that seq[i] to correspond to
...
...
@@ -490,34 +493,30 @@ def scan( fn
else
:
offset
=
0
if
maxtap
==
mintap
and
maxtap
!=
0
:
nw_seq
=
seq
[
'input'
][:
abs
(
maxtap
)]
elif
maxtap
-
k
!=
0
:
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
-
(
maxtap
-
k
)]
nw_seq
=
seq
[
'input'
][:
abs
(
maxtap
)]
elif
maxtap
-
k
!=
0
:
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
-
(
maxtap
-
k
)]
else
:
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
]
nw_seq
=
seq
[
'input'
][
offset
+
k
-
mintap
:
]
if
go_backwards
:
nw_seq
=
nw_seq
[::
-
1
]
scan_seqs
.
append
(
nw_seq
)
inner_seqs
.
append
(
nw_slice
)
inner_slices
.
append
(
actual_slice
)
scan_seqs
.
append
(
nw_seq
)
inner_seqs
.
append
(
nw_slice
)
inner_slices
.
append
(
actual_slice
)
n_seqs
+=
1
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
lengths_vec
=
[]
for
seq
in
scan_seqs
:
lengths_vec
.
append
(
seq
.
shape
[
0
]
)
lengths_vec
.
append
(
seq
.
shape
[
0
]
)
if
not
scan_utils
.
isNaN_or_Inf_or_None
(
n_steps
):
# ^ N_steps should also be considered
lengths_vec
.
append
(
tensor
.
as_tensor
(
n_steps
)
)
lengths_vec
.
append
(
tensor
.
as_tensor
(
n_steps
))
if
len
(
lengths_vec
)
==
0
:
if
len
(
lengths_vec
)
==
0
:
# ^ No information about the number of steps
raise
ValueError
(
' No information about the number of steps '
'provided. Either provide a value for '
...
...
@@ -534,11 +533,12 @@ def scan( fn
actual_n_steps
=
tensor
.
as_tensor
(
n_steps
)
# Add names -- it helps a lot when debugging
for
(
nw_seq
,
seq
)
in
zip
(
scan_seqs
,
seqs
):
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
[
'input'
]
.
name
+
'[
%
d:]'
%
k
if
getattr
(
seq
[
'input'
],
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
[
'input'
]
.
name
+
'[
%
d:]'
%
k
scan_seqs
=
[
seq
[:
actual_n_steps
]
for
seq
in
scan_seqs
]
scan_seqs
=
[
seq
[:
actual_n_steps
]
for
seq
in
scan_seqs
]
# Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function )
...
...
@@ -546,39 +546,35 @@ def scan( fn
# sit_sot = single input tap, single output tap (t + 0)
# nit_sot = no input tap, single output tap (t + 0)
# MIT_MOT -- not provided by the user only by the grad function
n_mit_mot
=
0
n_mit_mot_outs
=
0
mit_mot_scan_inputs
=
[]
mit_mot_inner_inputs
=
[]
n_mit_mot
=
0
n_mit_mot_outs
=
0
mit_mot_scan_inputs
=
[]
mit_mot_inner_inputs
=
[]
mit_mot_inner_outputs
=
[]
mit_mot_out_slices
=
[]
mit_mot_rightOrder
=
[]
mit_mot_out_slices
=
[]
mit_mot_rightOrder
=
[]
# SIT_SOT -- provided by the user
n_mit_sot
=
0
mit_sot_scan_inputs
=
[]
mit_sot_inner_inputs
=
[]
mit_sot_inner_slices
=
[]
n_mit_sot
=
0
mit_sot_scan_inputs
=
[]
mit_sot_inner_inputs
=
[]
mit_sot_inner_slices
=
[]
mit_sot_inner_outputs
=
[]
mit_sot_return_steps
=
{}
mit_sot_tap_array
=
[]
mit_sot_rightOrder
=
[]
n_sit_sot
=
0
sit_sot_scan_inputs
=
[]
sit_sot_inner_inputs
=
[]
sit_sot_inner_slices
=
[]
mit_sot_return_steps
=
{}
mit_sot_tap_array
=
[]
mit_sot_rightOrder
=
[]
n_sit_sot
=
0
sit_sot_scan_inputs
=
[]
sit_sot_inner_inputs
=
[]
sit_sot_inner_slices
=
[]
sit_sot_inner_outputs
=
[]
sit_sot_return_steps
=
{}
sit_sot_rightOrder
=
[]
sit_sot_return_steps
=
{}
sit_sot_rightOrder
=
[]
# go through outputs picking up time slices as needed
for
i
,
init_out
in
enumerate
(
outs_info
):
for
i
,
init_out
in
enumerate
(
outs_info
):
# Note that our convention dictates that if an output uses
# just the previous time step, as a initial state we will only
# provide a tensor of the same dimension as one time step; This
...
...
@@ -602,11 +598,12 @@ def scan( fn
if
config
.
compute_test_value
!=
'ignore'
:
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger
.
info
((
'Cannot compute test value for the inner '
'function of scan, input value missing
%
s'
),
e
)
_logger
.
info
((
'Cannot compute test value for the '
'inner function of scan, input value missing
%
s'
),
e
)
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
arg
.
name
=
init_out
[
'initial'
]
.
name
+
'[t-1]'
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
arg
.
name
=
init_out
[
'initial'
]
.
name
+
'[t-1]'
# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
...
...
@@ -614,117 +611,119 @@ def scan( fn
sit_sot_scan_inputs
.
append
(
scan_utils
.
expand
(
tensor
.
unbroadcast
(
tensor
.
shape_padleft
(
actual_arg
),
0
)
,
actual_n_steps
)
)
tensor
.
shape_padleft
(
actual_arg
),
0
)
,
actual_n_steps
))
sit_sot_inner_slices
.
append
(
actual_arg
)
if
i
in
return_steps
:
sit_sot_return_steps
[
n_sit_sot
]
=
return_steps
[
i
]
sit_sot_inner_inputs
.
append
(
arg
)
sit_sot_rightOrder
.
append
(
i
)
sit_sot_inner_inputs
.
append
(
arg
)
sit_sot_rightOrder
.
append
(
i
)
n_sit_sot
+=
1
elif
init_out
.
get
(
'taps'
,
None
):
elif
init_out
.
get
(
'taps'
,
None
):
if
numpy
.
any
(
numpy
.
array
(
init_out
.
get
(
'taps'
,[]))
>
0
):
if
numpy
.
any
(
numpy
.
array
(
init_out
.
get
(
'taps'
,
[]))
>
0
):
# Make sure we do not have requests for future values of a
# sequence we can not provide such values
raise
ValueError
(
'Can not use future taps of outputs'
,
init_out
)
raise
ValueError
(
'Can not use future taps of outputs'
,
init_out
)
# go through the taps
mintap
=
abs
(
numpy
.
min
(
init_out
[
'taps'
]))
mit_sot_tap_array
.
append
(
init_out
[
'taps'
]
)
mit_sot_tap_array
.
append
(
init_out
[
'taps'
]
)
idx_offset
=
abs
(
numpy
.
min
(
init_out
[
'taps'
]))
# Sequence
mit_sot_scan_inputs
.
append
(
scan_utils
.
expand
(
init_out
[
'initial'
][:
mintap
]
,
actual_n_steps
)
)
scan_utils
.
expand
(
init_out
[
'initial'
][:
mintap
],
actual_n_steps
)
)
if
i
in
return_steps
:
mit_sot_return_steps
[
n_mit_sot
]
=
return_steps
[
i
]
mit_sot_rightOrder
.
append
(
i
)
mit_sot_rightOrder
.
append
(
i
)
n_mit_sot
+=
1
for
k
in
init_out
[
'taps'
]:
# create a new slice
actual_nw_slice
=
init_out
[
'initial'
][
k
+
mintap
]
actual_nw_slice
=
init_out
[
'initial'
][
k
+
mintap
]
_init_out_var
=
tensor
.
as_tensor_variable
(
init_out
[
'initial'
])
_init_out_var_slice
=
_init_out_var
[
k
+
mintap
]
_init_out_var_slice
=
_init_out_var
[
k
+
mintap
]
nw_slice
=
_init_out_var_slice
.
type
()
# Try to transfer test_value to the new variable
if
config
.
compute_test_value
!=
'off'
:
try
:
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_init_out_var_slice
)
nw_slice
.
tag
.
test_value
=
gof
.
Op
.
_get_test_value
(
_init_out_var_slice
)
except
AttributeError
,
e
:
if
config
.
compute_test_value
!=
'ignore'
:
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger
.
info
((
'Cannot compute test value for the inner '
'function of scan, input value missing.
%
s'
),
e
)
_logger
.
info
((
'Cannot compute test value for '
'the inner function of scan, input value '
'missing.
%
s'
),
e
)
# give it a name or debugging and pretty printing
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
if
getattr
(
init_out
[
'initial'
],
'name'
,
None
)
is
not
None
:
if
k
>
0
:
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
'[t+
%
d]'
%
k
)
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
'[t+
%
d]'
%
k
)
elif
k
==
0
:
nw_slice
.
name
=
init_out
[
'initial'
]
.
name
+
'[t]'
else
:
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
'[t
%
d]'
%
k
)
mit_sot_inner_inputs
.
append
(
nw_slice
)
mit_sot_inner_slices
.
append
(
actual_nw_slice
)
nw_slice
.
name
=
(
init_out
[
'initial'
]
.
name
+
'[t
%
d]'
%
k
)
mit_sot_inner_inputs
.
append
(
nw_slice
)
mit_sot_inner_slices
.
append
(
actual_nw_slice
)
#NOTE: there is another case, in which we do not want to provide
# any previous value of the output to the inner function (i.e.
# a map); in that case we do not have to do anything ..
# Re-order args
max_mit_sot
=
numpy
.
max
(
[
-
1
]
+
mit_sot_rightOrder
)
+
1
max_sit_sot
=
numpy
.
max
(
[
-
1
]
+
sit_sot_rightOrder
)
+
1
n_elems
=
numpy
.
max
(
[
max_mit_sot
,
max_sit_sot
]
)
max_mit_sot
=
numpy
.
max
(
[
-
1
]
+
mit_sot_rightOrder
)
+
1
max_sit_sot
=
numpy
.
max
(
[
-
1
]
+
sit_sot_rightOrder
)
+
1
n_elems
=
numpy
.
max
([
max_mit_sot
,
max_sit_sot
]
)
_ordered_args
=
[[]
for
x
in
xrange
(
n_elems
)]
offset
=
0
for
idx
in
xrange
(
n_mit_sot
):
n_inputs
=
len
(
mit_sot_tap_array
[
idx
])
if
n_fixed_steps
in
[
1
,
-
1
]:
if
n_fixed_steps
in
[
1
,
-
1
]:
_ordered_args
[
mit_sot_rightOrder
[
idx
]]
=
\
mit_sot_inner_slices
[
offset
:
offset
+
n_inputs
]
mit_sot_inner_slices
[
offset
:
offset
+
n_inputs
]
else
:
_ordered_args
[
mit_sot_rightOrder
[
idx
]]
=
\
mit_sot_inner_inputs
[
offset
:
offset
+
n_inputs
]
mit_sot_inner_inputs
[
offset
:
offset
+
n_inputs
]
offset
+=
n_inputs
for
idx
in
xrange
(
n_sit_sot
):
if
n_fixed_steps
in
[
1
,
-
1
]:
if
n_fixed_steps
in
[
1
,
-
1
]:
_ordered_args
[
sit_sot_rightOrder
[
idx
]]
=
\
[
sit_sot_inner_slices
[
idx
]
]
[
sit_sot_inner_slices
[
idx
]
]
else
:
_ordered_args
[
sit_sot_rightOrder
[
idx
]]
=
\
[
sit_sot_inner_inputs
[
idx
]
]
[
sit_sot_inner_inputs
[
idx
]
]
ordered_args
=
[]
for
ls
in
_ordered_args
:
ordered_args
+=
ls
if
n_fixed_steps
in
[
1
,
-
1
]:
if
n_fixed_steps
in
[
1
,
-
1
]:
args
=
(
inner_slices
+
ordered_args
+
non_seqs
)
non_seqs
)
else
:
args
=
(
inner_seqs
+
args
=
(
inner_seqs
+
ordered_args
+
non_seqs
)
non_seqs
)
# add only the non-shared variables and non-constants to the arguments of the dummy
# function [ a function should not get shared variables or constants as input ]
# add only the non-shared variables and non-constants to the arguments of
# the dummy function [ a function should not get shared variables or
# constants as input ]
dummy_args
=
[
arg
for
arg
in
args
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
not
isinstance
(
arg
,
tensor
.
Constant
)
)]
not
isinstance
(
arg
,
tensor
.
Constant
))]
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
condition
,
outputs
,
updates
=
scan_utils
.
get_updates_and_outputs
(
fn
(
*
args
))
if
condition
is
not
None
:
as_while
=
True
...
...
@@ -734,14 +733,13 @@ def scan( fn
### Step 3. Check if we actually need scan and remove it if we don't
##
if
n_fixed_steps
in
[
1
,
-
1
]:
# We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have
if
condition
is
not
None
:
_logger
.
warning
(
(
'When the number of steps is fixed and equal to 1,
'
' the provided stopping condition, '
,
str
(
condition
)
,
' is ignored'
))
_logger
.
warning
(
(
'When the number of steps is fixed and equal
'
'to 1, the provided stopping condition, '
,
str
(
condition
),
' is ignored'
))
for
pos
,
inner_out
in
enumerate
(
outputs
):
# we need to see if we need to pad our sequences with an
...
...
@@ -750,16 +748,15 @@ def scan( fn
# then, if we return the output as given by the innner function
# this will represent only a slice and it will have one
# dimension less.
if
(
isinstance
(
inner_out
.
type
,
tensor
.
TensorType
)
and
if
(
isinstance
(
inner_out
.
type
,
tensor
.
TensorType
)
and
return_steps
.
get
(
pos
,
0
)
!=
1
):
outputs
[
pos
]
=
tensor
.
unbroadcast
(
tensor
.
shape_padleft
(
inner_out
),
0
)
tensor
.
shape_padleft
(
inner_out
),
0
)
if
len
(
outputs
)
==
1
:
outputs
=
outputs
[
0
]
return
(
outputs
,
updates
)
##
### Step 4. Compile the dummy function
##
...
...
@@ -779,11 +776,11 @@ def scan( fn
replace
=
dict
(
zip
(
non_seqs
,
fake_nonseqs
)))
all_inputs
=
itertools
.
ifilter
(
lambda
x
:
(
isinstance
(
x
,
gof
.
Variable
)
and
lambda
x
:
(
isinstance
(
x
,
gof
.
Variable
)
and
not
isinstance
(
x
,
SharedVariable
)
and
not
isinstance
(
x
,
gof
.
Constant
)
),
gof
.
graph
.
inputs
(
fake_outputs
)
)
extra_inputs
=
filter
(
lambda
x
:
x
not
in
args
+
fake_nonseqs
,
not
isinstance
(
x
,
gof
.
Constant
)),
gof
.
graph
.
inputs
(
fake_outputs
)
)
extra_inputs
=
filter
(
lambda
x
:
x
not
in
args
+
fake_nonseqs
,
all_inputs
)
non_seqs
+=
extra_inputs
## Note we do not use all_inputs directly since the order of variables
...
...
@@ -793,12 +790,11 @@ def scan( fn
dummy_outs
=
outputs
if
condition
is
not
None
:
dummy_outs
.
append
(
condition
)
dummy_f
=
function
(
dummy_args
,
dummy_outs
,
updates
=
updates
,
mode
=
compile
.
mode
.
Mode
(
linker
=
'py'
,
optimizer
=
None
)
)
dummy_f
=
function
(
dummy_args
,
dummy_outs
,
updates
=
updates
,
mode
=
compile
.
mode
.
Mode
(
linker
=
'py'
,
optimizer
=
None
))
##
### Step 5. Re-arange inputs of scan into a more strict order
...
...
@@ -807,7 +803,6 @@ def scan( fn
## Step 5.0 Check the outputs of the dummy function to see if they
## match with user provided data
# if the number of outputs to the function does not match the number of
# assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the
...
...
@@ -815,7 +810,7 @@ def scan( fn
tmp_dummy_f_outs
=
len
(
dummy_f
.
maker
.
outputs
)
if
as_while
:
tmp_dummy_f_outs
-=
1
if
not
(
tmp_dummy_f_outs
==
n_outs
or
outs_info
==
[]):
if
not
(
tmp_dummy_f_outs
==
n_outs
or
outs_info
==
[]):
raise
ValueError
(
'Please provide None as output_info for '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) '
)
...
...
@@ -824,95 +819,91 @@ def scan( fn
n_outs
=
len
(
dummy_f
.
maker
.
outputs
)
if
as_while
:
n_outs
=
n_outs
-
1
outs_info
=
[
dict
()
for
x
in
xrange
(
n_outs
)
]
outs_info
=
[
dict
()
for
x
in
xrange
(
n_outs
)]
## Step 5.1 Outputs with taps different then -1
for
i
,
out
in
enumerate
(
outs_info
):
if
'taps'
in
out
and
out
[
'taps'
]
!=
[
-
1
]:
mit_sot_inner_outputs
.
append
(
outputs
[
i
])
mit_sot_inner_outputs
.
append
(
outputs
[
i
])
## Step 5.2 Outputs with tap equal to -1
for
i
,
out
in
enumerate
(
outs_info
):
if
'taps'
in
out
and
out
[
'taps'
]
==
[
-
1
]:
sit_sot_inner_outputs
.
append
(
outputs
[
i
]
)
sit_sot_inner_outputs
.
append
(
outputs
[
i
])
## Step 5.3 Outputs that correspond to update rules of shared variables
givens
=
{}
n_shared_outs
=
0
shared_scan_inputs
=
[]
shared_inner_inputs
=
[]
givens
=
{}
n_shared_outs
=
0
shared_scan_inputs
=
[]
shared_inner_inputs
=
[]
shared_inner_outputs
=
[]
for
input
in
dummy_f
.
maker
.
expanded_inputs
:
if
isinstance
(
input
.
variable
,
SharedVariable
)
and
input
.
update
:
new_var
=
safe_new
(
input
.
variable
)
if
getattr
(
input
.
variable
,
'name'
,
None
)
is
not
None
:
if
getattr
(
input
.
variable
,
'name'
,
None
)
is
not
None
:
new_var
.
name
=
input
.
variable
.
name
+
'_copy'
shared_inner_inputs
.
append
(
new_var
)
shared_scan_inputs
.
append
(
input
.
variable
)
shared_inner_outputs
.
append
(
input
.
update
)
shared_inner_inputs
.
append
(
new_var
)
shared_scan_inputs
.
append
(
input
.
variable
)
shared_inner_outputs
.
append
(
input
.
update
)
givens
[
input
.
variable
]
=
new_var
n_shared_outs
+=
1
## Step 5.4 Outputs with no taps used in the input
n_nit_sot
=
0
n_nit_sot
=
0
nit_sot_inner_outputs
=
[]
nit_sot_return_steps
=
{}
nit_sot_rightOrder
=
[]
for
i
,
out
in
enumerate
(
outs_info
):
nit_sot_return_steps
=
{}
nit_sot_rightOrder
=
[]
for
i
,
out
in
enumerate
(
outs_info
):
if
not
'taps'
in
out
:
nit_sot_inner_outputs
.
append
(
outputs
[
i
]
)
nit_sot_inner_outputs
.
append
(
outputs
[
i
]
)
if
i
in
return_steps
:
nit_sot_return_steps
[
n_nit_sot
]
=
return_steps
[
i
]
nit_sot_rightOrder
.
append
(
i
)
nit_sot_rightOrder
.
append
(
i
)
n_nit_sot
+=
1
## Step 5.5 all other arguments including extra inputs
other_scan_args
=
[]
other_scan_args
=
[]
other_inner_args
=
[]
other_scan_args
+=
[
arg
for
arg
in
non_seqs
other_scan_args
+=
[
arg
for
arg
in
non_seqs
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
not
isinstance
(
arg
,
tensor
.
Constant
))]
## Step 5.6 all shared variables with no update rules
other_inner_args
+=
[
safe_new
(
arg
,
'_copy'
)
for
arg
in
non_seqs
other_inner_args
+=
[
safe_new
(
arg
,
'_copy'
)
for
arg
in
non_seqs
if
(
not
isinstance
(
arg
,
SharedVariable
)
and
not
isinstance
(
arg
,
tensor
.
Constant
))]
givens
.
update
(
dict
(
zip
(
other_scan_args
,
other_inner_args
)
))
other_shared_scan_args
=
[
arg
.
variable
for
arg
givens
.
update
(
dict
(
zip
(
other_scan_args
,
other_inner_args
)
))
other_shared_scan_args
=
[
arg
.
variable
for
arg
in
dummy_f
.
maker
.
expanded_inputs
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
not
arg
.
update
)
]
other_shared_inner_args
=
[
safe_new
(
arg
.
variable
,
'_copy'
)
for
arg
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
not
arg
.
update
)]
other_shared_inner_args
=
[
safe_new
(
arg
.
variable
,
'_copy'
)
for
arg
in
dummy_f
.
maker
.
expanded_inputs
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
not
arg
.
update
)
]
givens
.
update
(
dict
(
zip
(
other_shared_scan_args
,
other_shared_inner_args
)
)
)
if
(
isinstance
(
arg
.
variable
,
SharedVariable
)
and
not
arg
.
update
)]
givens
.
update
(
dict
(
zip
(
other_shared_scan_args
,
other_shared_inner_args
)))
##
### Step 6. Re-order the outputs and clone them replacing things
### using the givens
##
inner_inputs
=
(
inner_seqs
+
mit_mot_inner_inputs
+
mit_sot_inner_inputs
+
sit_sot_inner_inputs
+
shared_inner_inputs
+
inner_inputs
=
(
inner_seqs
+
mit_mot_inner_inputs
+
mit_sot_inner_inputs
+
sit_sot_inner_inputs
+
shared_inner_inputs
+
other_shared_inner_args
+
other_inner_args
)
other_inner_args
)
inner_outs
=
(
mit_mot_inner_outputs
+
mit_sot_inner_outputs
+
sit_sot_inner_outputs
+
nit_sot_inner_outputs
+
shared_inner_outputs
)
inner_outs
=
(
mit_mot_inner_outputs
+
mit_sot_inner_outputs
+
sit_sot_inner_outputs
+
nit_sot_inner_outputs
+
shared_inner_outputs
)
if
condition
is
not
None
:
inner_outs
.
append
(
condition
)
# Cuda is imported here, instead of being imported on top of the file
...
...
@@ -927,59 +918,58 @@ def scan( fn
# variables are put on GPU right aways >:| ,
new_givens
=
{}
for
w
,
w_copy
in
givens
.
iteritems
():
for
w
,
w_copy
in
givens
.
iteritems
():
if
(
isinstance
(
w
.
type
,
cuda
.
CudaNdarrayType
)
and
isinstance
(
w_copy
.
type
,
tensor
.
TensorType
)):
for
o
in
inner_outs
:
new_givens
=
traverse
(
o
,
w
,
w_copy
,
new_givens
)
new_givens
=
traverse
(
o
,
w
,
w_copy
,
new_givens
)
else
:
new_givens
[
w
]
=
w_copy
else
:
new_givens
=
givens
new_outs
=
scan_utils
.
clone
(
inner_outs
,
replace
=
new_givens
)
new_outs
=
scan_utils
.
clone
(
inner_outs
,
replace
=
new_givens
)
##
### Step 7. Create the Scan Op
##
tap_array
=
mit_sot_tap_array
+
[[
-
1
]
for
x
in
xrange
(
n_sit_sot
)]
info
=
{}
info
=
{}
info
[
'tap_array'
]
=
tap_array
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'tap_array'
]
=
tap_array
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'n_mit_sot'
]
=
n_mit_sot
info
[
'n_sit_sot'
]
=
n_sit_sot
info
[
'n_shared_outs'
]
=
n_shared_outs
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'truncate_gradient'
]
=
truncate_gradient
info
[
'name'
]
=
name
info
[
'mode'
]
=
mode
info
[
'inplace'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
as_while
info
[
'profile'
]
=
profile
local_op
=
scan_op
.
Scan
(
inner_inputs
,
new_outs
,
info
)
info
[
'n_mit_sot'
]
=
n_mit_sot
info
[
'n_sit_sot'
]
=
n_sit_sot
info
[
'n_shared_outs'
]
=
n_shared_outs
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'truncate_gradient'
]
=
truncate_gradient
info
[
'name'
]
=
name
info
[
'mode'
]
=
mode
info
[
'inplace'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
as_while
info
[
'profile'
]
=
profile
local_op
=
scan_op
.
Scan
(
inner_inputs
,
new_outs
,
info
)
##
### Step 8. Compute the outputs using the scan op
##
_scan_inputs
=
(
scan_seqs
+
mit_mot_scan_inputs
+
mit_sot_scan_inputs
+
sit_sot_scan_inputs
+
shared_scan_inputs
+
[
actual_n_steps
for
x
in
xrange
(
n_nit_sot
)
]
+
other_shared_scan_args
+
other_scan_args
)
_scan_inputs
=
(
scan_seqs
+
mit_mot_scan_inputs
+
mit_sot_scan_inputs
+
sit_sot_scan_inputs
+
shared_scan_inputs
+
[
actual_n_steps
for
x
in
xrange
(
n_nit_sot
)
]
+
other_shared_scan_args
+
other_scan_args
)
scan_inputs
=
[]
for
arg
in
[
actual_n_steps
]
+
_scan_inputs
:
for
arg
in
[
actual_n_steps
]
+
_scan_inputs
:
try
:
arg
=
tensor
.
as_tensor_variable
(
arg
)
except
TypeError
:
...
...
@@ -987,8 +977,8 @@ def scan( fn
# to make sure no input is a cuda ndarrays
pass
scan_inputs
+=
[
arg
]
scan_outs
=
local_op
(
*
scan_inputs
)
if
type
(
scan_outs
)
not
in
(
list
,
tuple
):
scan_outs
=
local_op
(
*
scan_inputs
)
if
type
(
scan_outs
)
not
in
(
list
,
tuple
):
scan_outs
=
[
scan_outs
]
##
### Step 9. Figure out which outs are update rules for shared variables
...
...
@@ -996,55 +986,57 @@ def scan( fn
##
update_map
=
Updates
()
def
remove_dimensions
(
outs
,
steps_return
,
offsets
=
None
):
def
remove_dimensions
(
outs
,
steps_return
,
offsets
=
None
):
out_ls
=
[]
for
idx
,
out
in
enumerate
(
outs
):
if
idx
in
steps_return
:
if
steps_return
[
idx
]
>
1
:
out_ls
.
append
(
out
[
-
steps_return
[
idx
]:]
)
out_ls
.
append
(
out
[
-
steps_return
[
idx
]:]
)
else
:
out_ls
.
append
(
out
[
-
1
]
)
out_ls
.
append
(
out
[
-
1
]
)
else
:
if
offsets
is
None
:
out_ls
.
append
(
out
)
out_ls
.
append
(
out
)
else
:
out_ls
.
append
(
out
[
offsets
[
idx
]:]
)
out_ls
.
append
(
out
[
offsets
[
idx
]:]
)
return
out_ls
offset
=
n_mit_mot
offsets
=
[
abs
(
numpy
.
min
(
x
))
for
x
in
mit_sot_tap_array
]
offsets
=
[
abs
(
numpy
.
min
(
x
))
for
x
in
mit_sot_tap_array
]
mit_sot_outs
=
remove_dimensions
(
scan_outs
[
offset
:
offset
+
n_mit_sot
]
,
mit_sot_return_steps
,
offsets
)
scan_outs
[
offset
:
offset
+
n_mit_sot
],
mit_sot_return_steps
,
offsets
)
offset
+=
n_mit_sot
offsets
=
[
1
for
x
in
xrange
(
n_sit_sot
)
]
offsets
=
[
1
for
x
in
xrange
(
n_sit_sot
)
]
sit_sot_outs
=
remove_dimensions
(
scan_outs
[
offset
:
offset
+
n_sit_sot
]
,
sit_sot_return_steps
,
offsets
)
scan_outs
[
offset
:
offset
+
n_sit_sot
],
sit_sot_return_steps
,
offsets
)
offset
+=
n_sit_sot
nit_sot_outs
=
remove_dimensions
(
scan_outs
[
offset
:
offset
+
n_nit_sot
]
,
nit_sot_return_steps
)
scan_outs
[
offset
:
offset
+
n_nit_sot
],
nit_sot_return_steps
)
offset
+=
n_nit_sot
for
idx
,
update_rule
in
enumerate
(
scan_outs
[
offset
:
offset
+
n_shared_outs
]):
for
idx
,
update_rule
in
enumerate
(
scan_outs
[
offset
:
offset
+
n_shared_outs
]):
update_map
[
shared_scan_inputs
[
idx
]]
=
update_rule
_scan_out_list
=
(
mit_sot_outs
+
sit_sot_outs
+
nit_sot_outs
)
_scan_out_list
=
(
mit_sot_outs
+
sit_sot_outs
+
nit_sot_outs
)
# Step 10. I need to reorder the outputs to be in the order expected by
# the user
rightOrder
=
(
mit_sot_rightOrder
+
sit_sot_rightOrder
+
nit_sot_rightOrder
)
scan_out_list
=
[
None
]
*
len
(
rightOrder
)
for
idx
,
pos
in
enumerate
(
rightOrder
):
scan_out_list
[
pos
]
=
_scan_out_list
[
idx
]
rightOrder
=
(
mit_sot_rightOrder
+
sit_sot_rightOrder
+
nit_sot_rightOrder
)
scan_out_list
=
[
None
]
*
len
(
rightOrder
)
for
idx
,
pos
in
enumerate
(
rightOrder
):
scan_out_list
[
pos
]
=
_scan_out_list
[
idx
]
if
len
(
scan_out_list
)
==
1
:
scan_out_list
=
scan_out_list
[
0
]
elif
len
(
scan_out_list
)
==
0
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论