Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5a3a1d82
提交
5a3a1d82
authored
11月 22, 2011
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #219 from pascanur/better_pushout_optimization
Better pushout optimization
上级
518fd20d
0fe3b745
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
585 行增加
和
526 行删除
+585
-526
scan_op.py
theano/scan_module/scan_op.py
+267
-262
scan_opt.py
theano/scan_module/scan_opt.py
+284
-264
test_scan.py
theano/scan_module/tests/test_scan.py
+34
-0
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
5a3a1d82
...
...
@@ -5,10 +5,10 @@ See scan.py for details on scan
"""
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
)
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
...
@@ -39,15 +39,11 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
class
Scan
(
PureOp
):
#
# OLD DOCUMENTATION CAN BE FOUND NEAR REVISION 2581
#
def
__init__
(
self
,
inputs
,
outputs
,
info
,
typeConstructor
=
None
def
__init__
(
self
,
inputs
,
outputs
,
info
,
typeConstructor
=
None
,
):
"""
:param inputs: inputs of the inner function of scan
...
...
@@ -56,7 +52,7 @@ class Scan(PureOp):
the scan op.
"""
# adding properties into self
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
__dict__
.
update
(
info
)
# I keep a version of info in self, to use in __eq__ and __hash__,
...
...
@@ -70,15 +66,16 @@ class Scan(PureOp):
jdx
=
0
if
typeConstructor
is
None
:
typeConstructor
=
lambda
broadcastable
,
dtype
:
TensorType
(
broadcastable
=
broadcastable
,
dtype
=
dtype
)
broadcastable
=
broadcastable
,
dtype
=
dtype
)
while
idx
<
self
.
n_mit_mot_outs
:
# Not that for mit_mot there are several output slices per
# output sequence
o
=
outputs
[
idx
]
o
=
outputs
[
idx
]
self
.
output_types
.
append
(
typeConstructor
(
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
,
dtype
=
o
.
type
.
dtype
)
typeConstructor
(
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
,
dtype
=
o
.
type
.
dtype
)
)
idx
+=
len
(
self
.
mit_mot_out_slices
[
jdx
])
jdx
+=
1
...
...
@@ -88,32 +85,32 @@ class Scan(PureOp):
for
o
in
outputs
[
idx
:
end
]:
self
.
output_types
.
append
(
typeConstructor
(
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
,
dtype
=
o
.
type
.
dtype
))
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
,
dtype
=
o
.
type
.
dtype
))
# shared outputs + possibly the ending condition
for
o
in
outputs
[
end
:]:
self
.
output_types
.
append
(
o
.
type
)
self
.
output_types
.
append
(
o
.
type
)
if
self
.
as_while
:
self
.
output_types
=
self
.
output_types
[:
-
1
]
self
.
destroy_map
=
{}
if
hasattr
(
self
,
'inplace'
)
and
self
.
inplace
:
if
hasattr
(
self
,
'inplace'
)
and
self
.
inplace
:
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
):
self
.
n_sit_sot
):
self
.
destroy_map
[
idx
]
=
[
idx
+
1
+
self
.
n_seqs
]
mode_instance
=
compile
.
mode
.
get_mode
(
self
.
mode
)
# if the default mode is used, and that mode is ProfileMode
# then we need to copy the mode otherwise the time for a given
# op will be counted multiple times
if
(
self
.
mode
is
None
and
isinstance
(
mode_instance
,
compile
.
profilemode
.
ProfileMode
)
):
if
(
self
.
mode
is
None
and
isinstance
(
mode_instance
,
compile
.
profilemode
.
ProfileMode
)):
mode_instance
=
compile
.
profilemode
.
ProfileMode
(
optimizer
=
mode_instance
.
provided_optimizer
,
linker
=
mode_instance
.
provided_linker
)
compile
.
profilemode
.
prof_mode_instance_to_print
.
append
(
mode_instance
)
optimizer
=
mode_instance
.
provided_optimizer
,
linker
=
mode_instance
.
provided_linker
)
compile
.
profilemode
.
prof_mode_instance_to_print
.
append
(
mode_instance
)
self
.
mode_instance
=
mode_instance
if
self
.
name
:
self
.
mode_instance
.
message
=
self
.
name
+
" sub profile"
...
...
@@ -122,7 +119,7 @@ class Scan(PureOp):
else
:
self
.
mode_instance
=
mode_instance
if
not
hasattr
(
self
,
'name'
)
or
self
.
name
is
None
:
if
not
hasattr
(
self
,
'name'
)
or
self
.
name
is
None
:
self
.
name
=
'scan_fn'
# to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the
...
...
@@ -130,27 +127,26 @@ class Scan(PureOp):
self
.
info
[
'name'
]
=
self
.
name
# Pre-computing some values to speed up perform
self
.
mintaps
=
[
numpy
.
min
(
x
)
for
x
in
self
.
tap_array
]
self
.
mintaps
+=
[
0
for
x
in
xrange
(
self
.
n_nit_sot
)
]
self
.
seqs_arg_offset
=
1
+
self
.
n_seqs
self
.
shared_arg_offset
=
(
self
.
seqs_arg_offset
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
nit_sot_arg_offset
=
(
self
.
shared_arg_offset
+
self
.
n_shared_outs
)
self
.
mintaps
=
[
numpy
.
min
(
x
)
for
x
in
self
.
tap_array
]
self
.
mintaps
+=
[
0
for
x
in
xrange
(
self
.
n_nit_sot
)
]
self
.
seqs_arg_offset
=
1
+
self
.
n_seqs
self
.
shared_arg_offset
=
(
self
.
seqs_arg_offset
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
nit_sot_arg_offset
=
(
self
.
shared_arg_offset
+
self
.
n_shared_outs
)
self
.
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
self
.
n_tap_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
if
not
self
.
info
[
'gpu'
]:
tmp_in
,
tmp_out
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
)
local_env
=
gof
.
Env
(
tmp_in
,
tmp_out
)
self
.
_cmodule_key
=
gof
.
CLinker
.
cmodule_key_
(
local_env
,[])
self
.
_cmodule_key
=
gof
.
CLinker
.
cmodule_key_
(
local_env
,
[])
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
else
:
self
.
_hash_inner_graph
=
self
.
info
[
'gpu_hash'
]
def
make_node
(
self
,
*
inputs
):
assert
numpy
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
# assert dtype is consistent
...
...
@@ -173,23 +169,23 @@ class Scan(PureOp):
# Flags that indicate which inputs are vectors
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
inputs
[
1
:
1
+
self
.
n_seqs
]
]
self
.
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
self
.
n_outs
)]
]
self
.
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
inputs
[
1
:
1
+
self
.
n_seqs
]
]
self
.
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
self
.
n_outs
)]]
self
.
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
# Check if input sequences and variables representing a slice of
# them have the same dtype
for
idx
in
xrange
(
self
.
n_seqs
):
if
inputs
[
1
+
idx
]
.
dtype
!=
self
.
inputs
[
idx
]
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'sequence'
,
str
(
inputs
[
1
+
idx
])
,
idx
,
inputs
[
1
+
idx
]
.
dtype
,
str
(
self
.
inputs
[
idx
])
,
self
.
inputs
[
idx
]
.
dtype
)
)
if
inputs
[
1
+
idx
]
.
dtype
!=
self
.
inputs
[
idx
]
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'sequence'
,
str
(
inputs
[
1
+
idx
]),
idx
,
inputs
[
1
+
idx
]
.
dtype
,
str
(
self
.
inputs
[
idx
]),
self
.
inputs
[
idx
]
.
dtype
)
)
# Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output
...
...
@@ -198,73 +194,73 @@ class Scan(PureOp):
# Maybe checking that ndim fits would be good as well !?
index_i
=
self
.
n_seqs
index_o
=
0
index
=
1
+
self
.
n_seqs
start
=
index
end
=
index
+
self
.
n_mit_mot
index
=
1
+
self
.
n_seqs
start
=
index
end
=
index
+
self
.
n_mit_mot
while
index
<
end
:
for
k
in
self
.
tap_array
[
index
-
start
]:
for
k
in
self
.
tap_array
[
index
-
start
]:
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
str
(
inputs
[
index
])
,
index
,
inputs
[
index
]
.
dtype
,
str
(
self
.
inputs
[
index_i
])
,
self
.
inputs
[
index_i
]
.
dtype
)
)
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
str
(
inputs
[
index
]),
index
,
inputs
[
index
]
.
dtype
,
str
(
self
.
inputs
[
index_i
]),
self
.
inputs
[
index_i
]
.
dtype
)
)
index_i
+=
1
for
k
in
self
.
mit_mot_out_slices
[
index
-
start
]:
for
k
in
self
.
mit_mot_out_slices
[
index
-
start
]:
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
])
,
index
,
inputs
[
index
]
.
dtype
,
self
.
outputs
[
index_o
]
.
dtype
)
)
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
index
,
inputs
[
index
]
.
dtype
,
self
.
outputs
[
index_o
]
.
dtype
)
)
index_o
+=
1
index
+=
1
# Same checks as above but for outputs of type mit_sot and sit_sot
end
+=
self
.
n_mit_sot
+
self
.
n_sit_sot
while
index
<
end
:
for
k
in
self
.
tap_array
[
index
-
start
]:
for
k
in
self
.
tap_array
[
index
-
start
]:
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'Initial state'
,
str
(
inputs
[
index
])
,
index
,
inputs
[
index
]
.
dtype
,
str
(
self
.
inputs
[
index_i
])
,
self
.
inputs
[
index_i
]
.
dtype
)
)
raise
ValueError
(
err_msg1
%
(
'Initial state'
,
str
(
inputs
[
index
]),
index
,
inputs
[
index
]
.
dtype
,
str
(
self
.
inputs
[
index_i
]),
self
.
inputs
[
index_i
]
.
dtype
)
)
index_i
+=
1
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
])
,
index
,
inputs
[
index
]
.
dtype
,
self
.
outputs
[
index_o
]
.
dtype
)
)
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
index
,
inputs
[
index
]
.
dtype
,
self
.
outputs
[
index_o
]
.
dtype
)
)
index_o
+=
1
index
+=
1
index
+=
1
# Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?!
end
+=
self
.
n_shared_outs
end
+=
self
.
n_shared_outs
index_o
+=
self
.
n_nit_sot
while
index
<
end
:
if
(
hasattr
(
inputs
[
index
],
'dtype'
)
and
if
(
hasattr
(
inputs
[
index
],
'dtype'
)
and
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
):
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
])
,
index
,
inputs
[
index
]
.
dtype
,
self
.
outputs
[
index_o
]
.
dtype
)
)
index
+=
1
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
index
,
inputs
[
index
]
.
dtype
,
self
.
outputs
[
index_o
]
.
dtype
)
)
index
+=
1
index_o
+=
1
for
x
in
inputs
[
index
:
index
+
self
.
n_nit_sot
]:
for
x
in
inputs
[
index
:
index
+
self
.
n_nit_sot
]:
# For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
if
(
str
(
x
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
or
if
(
str
(
x
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
or
x
.
ndim
!=
0
):
raise
ValueError
(
'For output
%
d you need to provide a '
'scalar int !'
,
x
)
'scalar int !'
,
x
)
apply_node
=
Apply
(
self
,
inputs
,
[
t
()
for
t
in
self
.
output_types
])
apply_node
=
Apply
(
self
,
inputs
,
[
t
()
for
t
in
self
.
output_types
])
return
apply_node
def
__eq__
(
self
,
other
):
...
...
@@ -284,7 +280,7 @@ class Scan(PureOp):
# check. Namely, do the internal graph represent same
# computations
for
self_in
,
other_in
in
zip
(
self
.
inputs
,
other
.
inputs
):
if
self_in
.
type
!=
other_in
.
type
:
if
self_in
.
type
!=
other_in
.
type
:
return
False
if
not
scan_utils
.
equal_computations
(
self
.
outputs
,
...
...
@@ -308,21 +304,19 @@ class Scan(PureOp):
else
:
name
=
'for'
if
self
.
inplace
:
aux_txt
=
'
%
s{inplace,
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
if
self
.
inplace
:
aux_txt
=
'
%
s{inplace,
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
else
:
aux_txt
=
'
%
s{
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
aux_txt
=
'
%
s{
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
return
aux_txt
def
__hash__
(
self
):
return
(
hash
(
type
(
self
))
^
return
(
hash
(
type
(
self
))
^
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self
.
_hash_inner_graph
^
scan_utils
.
hash_listsDictsTuples
(
self
.
info
)
)
scan_utils
.
hash_listsDictsTuples
(
self
.
info
))
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
"""
...
...
@@ -348,7 +342,6 @@ class Scan(PureOp):
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage
=
[
storage_map
[
r
]
for
r
in
node
.
inputs
]
node_output_storage
=
[
storage_map
[
r
]
for
r
in
node
.
outputs
]
node_input_compute
=
[
compute_map
[
r
]
for
r
in
node
.
inputs
]
...
...
@@ -357,64 +350,65 @@ class Scan(PureOp):
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
slices
=
(
self
.
n_mit_mot_outs
+
slices
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
wrapped_inputs
=
[
Param
(
x
,
borrow
=
True
)
for
x
in
self
.
inputs
]
self
.
n_nit_sot
)
wrapped_inputs
=
[
Param
(
x
,
borrow
=
True
)
for
x
in
self
.
inputs
]
wrapped_outputs
=
[
Out
(
x
,
borrow
=
True
)
for
x
in
self
.
outputs
[:
slices
]
]
self
.
outputs
[:
slices
]]
wrapped_outputs
+=
self
.
outputs
[
slices
:]
profile
=
None
if
(
theano
.
config
.
profile
or
(
isinstance
(
self
.
profile
,
(
basestring
,
bool
,
int
))
if
(
theano
.
config
.
profile
or
(
isinstance
(
self
.
profile
,
(
basestring
,
bool
,
int
))
and
self
.
profile
)):
if
isinstance
(
self
.
profile
,
basestring
):
profile
=
ScanProfileStats
(
name
=
self
.
profile
)
profile
=
ScanProfileStats
(
name
=
self
.
profile
)
else
:
profile
=
ScanProfileStats
(
name
=
self
.
name
)
profile
=
ScanProfileStats
(
name
=
self
.
name
)
elif
self
.
profile
:
profile
=
self
.
profile
self
.
fn
=
function
(
wrapped_inputs
,
wrapped_outputs
,
mode
=
self
.
mode_instance
,
name
=
self
.
name
,
profile
=
profile
)
mode
=
self
.
mode_instance
,
name
=
self
.
name
,
profile
=
profile
)
try
:
cython_mintaps
=
numpy
.
asarray
(
self
.
mintaps
,
dtype
=
'int32'
)
raise
ImportError
cython_mintaps
=
numpy
.
asarray
(
self
.
mintaps
,
dtype
=
'int32'
)
cython_tap_array_len
=
\
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
tap_array
],
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
tap_array
],
dtype
=
'int32'
)
if
len
(
self
.
tap_array
)
==
0
:
d1
=
0
else
:
d1
=
numpy
.
max
(
cython_tap_array_len
)
d0
=
len
(
self
.
tap_array
)
cython_tap_array
=
numpy
.
zeros
((
d0
,
d1
),
dtype
=
'int32'
)
cython_tap_array
=
numpy
.
zeros
((
d0
,
d1
),
dtype
=
'int32'
)
for
_d0
in
range
(
d0
):
for
_d1
in
range
(
cython_tap_array_len
[
_d0
]):
cython_tap_array
[
_d0
,
_d1
]
=
self
.
tap_array
[
_d0
][
_d1
]
cython_tap_array
[
_d0
,
_d1
]
=
self
.
tap_array
[
_d0
][
_d1
]
cython_mit_mot_out_nslices
=
\
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
],
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
],
dtype
=
'int32'
)
if
len
(
self
.
mit_mot_out_slices
)
==
0
:
d1
=
0
else
:
d1
=
numpy
.
max
(
cython_mit_mot_out_nslices
)
d0
=
len
(
self
.
mit_mot_out_slices
)
cython_mit_mot_out_slices
=
numpy
.
zeros
((
d0
,
d1
),
cython_mit_mot_out_slices
=
numpy
.
zeros
((
d0
,
d1
),
dtype
=
'int32'
)
for
_d0
in
range
(
d0
):
for
_d1
in
range
(
cython_mit_mot_out_nslices
[
_d0
]):
cython_mit_mot_out_slices
[
_d0
,
_d1
]
=
\
cython_mit_mot_out_slices
[
_d0
,
_d1
]
=
\
self
.
mit_mot_out_slices
[
_d0
][
_d1
]
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
]
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
node
.
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
self
.
n_outs
)]
]
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
]
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
node
.
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
self
.
n_outs
)]
]
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
cython_vector_seqs
=
numpy
.
asarray
(
self
.
vector_seqs
,
dtype
=
'int32'
)
...
...
@@ -448,6 +442,7 @@ class Scan(PureOp):
except
ImportError
:
p
=
self
.
execute
# default arguments are stored in the closure of `rval`
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
for
o
in
node
.
outputs
:
...
...
@@ -463,14 +458,14 @@ class Scan(PureOp):
return
self
.
inputs
[:
self
.
n_seqs
]
def
outer_seqs
(
self
,
node
):
return
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
return
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
def
inner_mitmot
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
return
self
.
inputs
[
self
.
n_seqs
:
self
.
n_seqs
+
n_taps
]
def
outer_mitmot
(
self
,
node
):
return
node
.
inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
return
node
.
inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
def
inner_mitmot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
...
...
@@ -490,80 +485,80 @@ class Scan(PureOp):
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
return
self
.
inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
self
.
n_seqs
+
ntaps_upto_sit_sot
]
return
self
.
inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
self
.
n_seqs
+
ntaps_upto_sit_sot
]
def
outer_mitsot
(
self
,
node
):
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
def
inner_mitsot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
self
.
outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
return
self
.
outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
def
outer_mitsot_outs
(
self
,
node
):
return
node
.
outputs
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
return
node
.
outputs
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
mitsot_taps
(
self
):
return
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
return
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
inner_sitsot
(
self
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
return
self
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
self
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot
(
self
,
node
):
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot
(
self
,
node
):
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
inner_sitsot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
return
self
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
self
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot_outs
(
self
,
node
):
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
return
node
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
node
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_nitsot
(
self
,
node
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
return
node
.
inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
node
.
inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_nitsot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
return
self
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
self
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
outer_nitsot_outs
(
self
,
node
):
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
return
node
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
node
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_shared
(
self
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
return
self
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
self
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared
(
self
,
node
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
return
node
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
node
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_shared_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
+
self
.
n_nit_sot
return
self
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
self
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared_outs
(
self
,
node
):
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
return
node
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
node
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_non_seqs
(
self
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
...
...
@@ -574,12 +569,11 @@ class Scan(PureOp):
return
self
.
inputs
[
offset
:]
def
outer_non_seqs
(
self
,
node
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
return
node
.
inputs
[
offset
:]
def
execute
(
self
,
node
,
args
,
outs
):
def
execute
(
self
,
node
,
args
,
outs
):
"""
The args are packed like this:
...
...
@@ -607,7 +601,7 @@ class Scan(PureOp):
# negative flip sequences around, and make n_steps positive
t0_call
=
time
.
time
()
t_fn
=
0
n_steps
=
args
[
0
]
n_steps
=
args
[
0
]
seqs
=
[]
if
n_steps
<
0
:
n_steps
=
abs
(
n_steps
)
...
...
@@ -616,7 +610,7 @@ class Scan(PureOp):
raise
ValueError
((
'Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'seq.shape):'
),
n_steps
,
node
.
inputs
[
1
+
idx
],
node
.
inputs
[
1
+
idx
],
seq
.
shape
)
seqs
.
append
(
seq
[::
-
1
])
else
:
...
...
@@ -625,35 +619,37 @@ class Scan(PureOp):
raise
ValueError
((
'Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'seq.shape):'
),
n_steps
,
node
.
inputs
[
1
+
idx
],
node
.
inputs
[
1
+
idx
],
seq
.
shape
)
seqs
.
append
(
seq
)
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containting the length of each output
# pos -- map containing the current position of each output
# pos -- map containing the current position of each
# output
store_steps
=
[
arg
.
shape
[
0
]
for
arg
store_steps
=
[
arg
.
shape
[
0
]
for
arg
in
args
[
self
.
seqs_arg_offset
:
self
.
shared_arg_offset
]
]
store_steps
+=
[
arg
for
arg
in
self
.
shared_arg_offset
]]
store_steps
+=
[
arg
for
arg
in
args
[
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
]
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
]
]
pos
=
[
(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
)]
pos
=
[
(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
)]
# 2.1 Create storage space for outputs
for
idx
in
xrange
(
self
.
n_outs
):
if
self
.
inplace
:
# ^ Case 1. Outputs should be computed inplace of their
# initial state
outs
[
idx
][
0
]
=
args
[
self
.
seqs_arg_offset
+
idx
]
elif
(
outs
[
idx
][
0
]
is
not
None
and
outs
[
idx
][
0
]
.
shape
[
1
:]
==
args
[
self
.
seqs_arg_offset
+
idx
]
.
shape
[
1
:]
and
outs
[
idx
][
0
]
.
shape
[
0
]
>=
store_steps
[
idx
]
):
outs
[
idx
][
0
]
=
args
[
self
.
seqs_arg_offset
+
idx
]
elif
(
outs
[
idx
][
0
]
is
not
None
and
outs
[
idx
][
0
]
.
shape
[
1
:]
==
args
[
self
.
seqs_arg_offset
+
idx
]
.
shape
[
1
:]
and
outs
[
idx
][
0
]
.
shape
[
0
]
>=
store_steps
[
idx
]):
# Put in the values of the initial state
outs
[
idx
][
0
]
=
outs
[
idx
][
0
][:
store_steps
[
idx
]]
outs
[
idx
][
0
]
=
outs
[
idx
][
0
][:
store_steps
[
idx
]]
if
idx
>
self
.
n_mit_mot
:
l
=
-
self
.
mintaps
[
idx
]
outs
[
idx
][
0
][:
l
]
=
args
[
self
.
seqs_arg_offset
+
idx
][:
l
]
...
...
@@ -662,28 +658,28 @@ class Scan(PureOp):
else
:
outs
[
idx
][
0
]
=
args
[
self
.
seqs_arg_offset
+
idx
]
.
copy
()
offset
=
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
other_args
=
args
[
offset
:]
input_storage
=
self
.
fn
.
input_storage
output_storage
=
self
.
fn
.
output_storage
fn
=
self
.
fn
.
fn
offset
=
(
self
.
n_seqs
+
sum
(
map
(
len
,
self
.
tap_array
[:
self
.
n_outs
]))
+
offset
=
(
self
.
n_seqs
+
sum
(
map
(
len
,
self
.
tap_array
[:
self
.
n_outs
]))
+
self
.
n_shared_outs
)
for
idx
in
xrange
(
len
(
other_args
)):
input_storage
[
idx
+
offset
]
.
storage
[
0
]
=
other_args
[
idx
]
input_storage
[
idx
+
offset
]
.
storage
[
0
]
=
other_args
[
idx
]
i
=
0
cond
=
True
############## THE MAIN LOOP #########################
#for i in xrange(n_steps):
while
(
i
<
n_steps
)
and
cond
:
while
(
i
<
n_steps
)
and
cond
:
# sequences over which scan iterates
# 3. collect input slices
for
idx
in
xrange
(
self
.
n_seqs
):
if
self
.
vector_seqs
[
idx
]:
input_storage
[
idx
]
.
storage
[
0
]
=
seqs
[
idx
][
i
:
i
+
1
]
.
reshape
(())
input_storage
[
idx
]
.
storage
[
0
]
=
\
seqs
[
idx
][
i
:
i
+
1
]
.
reshape
(())
else
:
input_storage
[
idx
]
.
storage
[
0
]
=
seqs
[
idx
][
i
]
...
...
@@ -691,26 +687,25 @@ class Scan(PureOp):
for
idx
in
xrange
(
self
.
n_outs
):
if
self
.
vector_outs
[
idx
]:
for
tap
in
self
.
tap_array
[
idx
]:
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
input_storage
[
offset
]
.
storage
[
0
]
=
\
outs
[
idx
][
0
][
_idx
:
_idx
+
1
]
.
reshape
(())
outs
[
idx
][
0
][
_idx
:
_idx
+
1
]
.
reshape
(())
offset
+=
1
else
:
for
tap
in
self
.
tap_array
[
idx
]:
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
input_storage
[
offset
]
.
storage
[
0
]
=
outs
[
idx
][
0
][
_idx
]
offset
+=
1
a_offset
=
self
.
shared_arg_offset
o_offset
=
self
.
n_outs
+
self
.
n_nit_sot
if
i
==
0
:
for
j
in
xrange
(
self
.
n_shared_outs
):
input_storage
[
offset
]
.
storage
[
0
]
=
args
[
a_offset
+
j
]
input_storage
[
offset
]
.
storage
[
0
]
=
args
[
a_offset
+
j
]
offset
+=
1
else
:
for
j
in
xrange
(
self
.
n_shared_outs
):
input_storage
[
offset
]
.
storage
[
0
]
=
outs
[
o_offset
+
j
][
0
]
input_storage
[
offset
]
.
storage
[
0
]
=
outs
[
o_offset
+
j
][
0
]
offset
+=
1
# 4. collecting slices where the output should be stored
...
...
@@ -718,23 +713,24 @@ class Scan(PureOp):
output_storage
[
idx
]
.
storage
[
0
]
=
None
offset
=
self
.
n_mit_mot_outs
if
i
!=
0
and
self
.
n_nit_sot
>
0
:
if
i
!=
0
and
self
.
n_nit_sot
>
0
:
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
):
if
(
store_steps
[
idx
+
self
.
n_mit_mot
]
==
1
or
self
.
vector_outs
[
idx
+
self
.
n_mit_mot
]):
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
if
(
store_steps
[
idx
+
self
.
n_mit_mot
]
==
1
or
self
.
vector_outs
[
idx
+
self
.
n_mit_mot
]):
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
else
:
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
\
outs
[
idx
+
self
.
n_mit_mot
][
0
][
pos
[
idx
+
self
.
n_mit_mot
]]
_pos0
=
idx
+
self
.
n_mit_mot
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
\
outs
[
_pos0
][
0
][
pos
[
_pos0
]]
else
:
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
):
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
offset
+=
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
offset
+=
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
for
idx
in
xrange
(
self
.
n_shared_outs
):
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
# If condition add it to the mix
if
self
.
as_while
:
pdx
=
offset
+
self
.
n_shared_outs
...
...
@@ -762,97 +758,102 @@ class Scan(PureOp):
# 5.1 Copy over the values for mit_mot outputs
for
j
in
xrange
(
self
.
n_mit_mot
):
for
k
in
self
.
mit_mot_out_slices
[
j
]:
outs
[
j
][
0
][
k
+
pos
[
j
]]
=
output_storage
[
offset_out
]
.
storage
[
0
]
outs
[
j
][
0
][
k
+
pos
[
j
]]
=
\
output_storage
[
offset_out
]
.
storage
[
0
]
offset_out
+=
1
# 5.2 Copy over the values for mit_sot/sit_sot outputs
begin
=
self
.
n_mit_mot
end
=
self
.
n_outs
end
=
self
.
n_outs
offset_out
-=
self
.
n_mit_mot
for
j
in
xrange
(
begin
,
end
):
if
(
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]
or
outs
[
j
][
0
][
pos
[
j
]]
is
not
output_storage
[
offset_out
+
j
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
=
output_storage
[
offset_out
+
j
]
.
storage
[
0
]
if
(
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]
or
outs
[
j
][
0
][
pos
[
j
]]
is
not
output_storage
[
offset_out
+
j
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
=
\
output_storage
[
offset_out
+
j
]
.
storage
[
0
]
# 5.3 Copy over the values for nit_sot outputs
begin
=
end
end
+=
self
.
n_nit_sot
for
j
in
xrange
(
begin
,
end
):
begin
=
end
end
+=
self
.
n_nit_sot
for
j
in
xrange
(
begin
,
end
):
if
i
==
0
:
jout
=
j
+
offset_out
shape
=
(
store_steps
[
j
],)
+
output_storage
[
jout
]
.
storage
[
0
]
.
shape
jout
=
j
+
offset_out
shape
=
(
store_steps
[
j
],)
+
\
output_storage
[
jout
]
.
storage
[
0
]
.
shape
if
len
(
output_storage
[
jout
]
.
storage
[
0
]
.
shape
)
==
0
:
self
.
vector_outs
[
j
]
=
True
dtype
=
output_storage
[
jout
]
.
storage
[
0
]
.
dtype
if
(
outs
[
j
][
0
]
is
None
or
outs
[
j
][
0
]
.
shape
[
0
]
<
store_steps
[
j
]
or
outs
[
j
][
0
]
.
shape
[
1
:]
!=
shape
[
1
:]
or
outs
[
j
][
0
]
.
dtype
!=
dtype
):
outs
[
j
][
0
]
.
dtype
!=
dtype
):
if
self
.
gpu
:
outs
[
j
][
0
]
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
(
shape
)
_cuda
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
outs
[
j
][
0
]
=
_cuda
.
zeros
(
shape
)
else
:
outs
[
j
][
0
]
=
numpy
.
zeros
(
shape
,
dtype
)
elif
outs
[
j
][
0
]
.
shape
[
0
]
!=
store_steps
[
j
]:
outs
[
j
][
0
]
=
outs
[
j
][
0
][:
store_steps
[
j
]]
outs
[
j
][
0
][
pos
[
j
]]
=
output_storage
[
jout
]
.
storage
[
0
]
elif
(
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]
or
outs
[
j
][
0
][
pos
[
j
]]
is
not
output_storage
[
j
+
offset_out
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
=
output_storage
[
j
+
offset_out
]
.
storage
[
0
]
outs
[
j
][
0
][
pos
[
j
]]
is
not
output_storage
[
j
+
offset_out
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
=
\
output_storage
[
j
+
offset_out
]
.
storage
[
0
]
# 5.4 Copy over the values for outputs corresponding to shared
# variables
begin
=
end
end
+=
self
.
n_shared_outs
for
j
in
xrange
(
begin
,
end
):
jout
=
j
+
offset_out
begin
=
end
end
+=
self
.
n_shared_outs
for
j
in
xrange
(
begin
,
end
):
jout
=
j
+
offset_out
outs
[
j
][
0
]
=
output_storage
[
jout
]
.
storage
[
0
]
pos
=
[
(
idx
+
1
)
%
store
for
idx
,
store
in
itertools
.
izip
(
pos
,
store_steps
)
]
i
=
i
+
1
pos
=
[(
idx
+
1
)
%
store
for
idx
,
store
in
itertools
.
izip
(
pos
,
store_steps
)]
i
=
i
+
1
# 6. Check if you need to re-order output buffers
begin
=
self
.
n_mit_mot
end
=
self
.
n_outs
+
self
.
n_nit_sot
end
=
self
.
n_outs
+
self
.
n_nit_sot
for
idx
in
xrange
(
begin
,
end
):
min_tap
=
self
.
mintaps
[
idx
]
if
(
store_steps
[
idx
]
<
i
-
self
.
mintaps
[
idx
]
and
pos
[
idx
]
<
store_steps
[
idx
]
):
if
(
store_steps
[
idx
]
<
i
-
self
.
mintaps
[
idx
]
and
pos
[
idx
]
<
store_steps
[
idx
]):
pdx
=
pos
[
idx
]
if
pdx
<
store_steps
[
idx
]
//
2
:
shape
=
(
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
if
pdx
<
store_steps
[
idx
]
//
2
:
shape
=
(
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
cuda
.
CudaNdarray
):
tmp
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
(
shape
)
_cuda
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
tmp
=
_cuda
.
zeros
(
shape
)
else
:
tmp
=
numpy
.
empty
(
shape
)
tmp
[:]
=
outs
[
idx
][
0
][:
pdx
]
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
outs
[
idx
][
0
][
pdx
:]
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
tmp
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
outs
[
idx
][
0
][
pdx
:]
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
tmp
del
tmp
else
:
shape
=
(
store_steps
[
idx
]
-
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
shape
=
(
store_steps
[
idx
]
-
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
cuda
.
CudaNdarray
):
tmp
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
(
shape
)
_cuda
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
tmp
=
_cuda
.
zeros
(
shape
)
else
:
tmp
=
numpy
.
empty
(
shape
)
tmp
[:]
=
outs
[
idx
][
0
][
pdx
:]
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
outs
[
idx
][
0
][:
pdx
]
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
tmp
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
outs
[
idx
][
0
][:
pdx
]
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
tmp
del
tmp
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is
# not actually computed
elif
store_steps
[
idx
]
>
i
-
self
.
mintaps
[
idx
]:
outs
[
idx
][
0
][
i
-
self
.
mintaps
[
idx
]:]
=
0
outs
[
idx
][
0
][
i
-
self
.
mintaps
[
idx
]:]
=
0
# This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output
# to have that length ( and not the maximal length possible)
...
...
@@ -883,7 +884,7 @@ class Scan(PureOp):
profile
.
callcount
+=
1
profile
.
nbsteps
+=
n_steps
profile
.
call_time
+=
t_call
profile
.
vm_call_time
+=
t_fn
profile
.
vm_call_time
+=
t_fn
if
hasattr
(
self
.
fn
.
fn
,
'update_profile'
):
self
.
fn
.
fn
.
update_profile
(
profile
)
...
...
@@ -896,7 +897,7 @@ class Scan(PureOp):
#self.fn.maker.mode.fn_time += t_fn
# Old Profile Mode */
self
.
t_call
=
t_call
self
.
t_fn
=
t_fn
self
.
t_fn
=
t_fn
### Infer Shape
def
infer_shape
(
self
,
node
,
input_shapes
):
...
...
@@ -905,26 +906,27 @@ class Scan(PureOp):
# is the shape of self.inputs[i]
# sequences
seqs_shape
=
[
x
[
1
:]
for
x
in
input_shapes
[
1
:
1
+
self
.
n_seqs
]
]
seqs_shape
=
[
x
[
1
:]
for
x
in
input_shapes
[
1
:
1
+
self
.
n_seqs
]
]
# mit_mot, mit_sot, sit_sot
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
outs_shape
=
[]
for
idx
in
xrange
(
n_outs
):
for
k
in
self
.
tap_array
[
idx
]:
outs_shape
+=
[
input_shapes
[
idx
+
self
.
n_seqs
+
1
][
1
:]
]
outs_shape
+=
[
input_shapes
[
idx
+
self
.
n_seqs
+
1
][
1
:]
]
# shared_outs
offset
=
1
+
self
.
n_seqs
+
n_outs
for
idx
in
xrange
(
self
.
n_shared_outs
):
outs_shape
+=
[
input_shapes
[
idx
+
offset
]
]
outs_shape
+=
[
input_shapes
[
idx
+
offset
]
]
# non_sequences
offset
+=
self
.
n_nit_sot
+
self
.
n_shared_outs
inner_ins_shapes
=
seqs_shape
+
outs_shape
+
input_shapes
[
offset
:]
assert
len
(
inner_ins_shapes
)
==
len
(
self
.
inputs
)
# Non-sequences have a direct equivalent from self.inputs in node.inputs
# Non-sequences have a direct equivalent from self.inputs in
# node.inputs
inner_non_sequences
=
self
.
inputs
[
len
(
seqs_shape
)
+
len
(
outs_shape
):]
out_equivalent
=
{}
for
in_ns
,
out_ns
in
zip
(
inner_non_sequences
,
node
.
inputs
[
offset
:]):
...
...
@@ -934,22 +936,22 @@ class Scan(PureOp):
else
:
self_outs
=
self
.
outputs
outs_shape
=
scan_utils
.
infer_shape
(
outs
=
self_outs
,
inputs
=
self
.
inputs
,
input_shapes
=
inner_ins_shapes
)
outs
=
self_outs
,
inputs
=
self
.
inputs
,
input_shapes
=
inner_ins_shapes
)
# Will be used to check if outs_shape can be expressed without using
# variables in self.inputs.
# The shapes of node.inputs are valid.
validator
=
scan_utils
.
Validator
(
valid
=
input_shapes
,
invalid
=
self
.
inputs
,
valid_equivalent
=
out_equivalent
)
valid
=
input_shapes
,
invalid
=
self
.
inputs
,
valid_equivalent
=
out_equivalent
)
offset
=
1
+
self
.
n_seqs
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
offset
+=
n_outs
for
x
in
xrange
(
self
.
n_nit_sot
):
out_shape_x
=
outs_shape
[
n_outs
+
x
]
out_shape_x
=
outs_shape
[
n_outs
+
x
]
if
out_shape_x
is
None
:
# This output is not a tensor, and has no shape
scan_outs
.
append
(
None
)
...
...
@@ -957,10 +959,10 @@ class Scan(PureOp):
# We need to make sure that we can compute the shapes from
# node.inputs, and constants, without using the variables
# in the inner function.
r
=
node
.
outputs
[
n_outs
+
x
]
r
=
node
.
outputs
[
n_outs
+
x
]
assert
r
.
ndim
==
1
+
len
(
out_shape_x
)
shp
=
[
node
.
inputs
[
offset
+
self
.
n_shared_outs
+
x
]]
for
i
,
shp_i
in
zip
(
xrange
(
1
,
r
.
ndim
),
out_shape_x
):
shp
=
[
node
.
inputs
[
offset
+
self
.
n_shared_outs
+
x
]]
for
i
,
shp_i
in
zip
(
xrange
(
1
,
r
.
ndim
),
out_shape_x
):
# Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates
# whether variable is shp_i (if True), or a valid
...
...
@@ -976,34 +978,32 @@ class Scan(PureOp):
shp
.
append
(
v_shp_i
[
0
])
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
self
.
n_shared_outs
]
]
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
self
.
n_shared_outs
]
]
return
scan_outs
### GRAD FUNCTION
def
grad
(
self
,
args
,
g_outs
):
# 1. forward pass - get the outputs after applying scan
scan_outputs
=
self
(
*
args
)
# 2. make sure they are given as a list
if
not
(
type
(
scan_outputs
)
in
(
list
,
tuple
)):
if
not
(
type
(
scan_outputs
)
in
(
list
,
tuple
)):
scan_outputs
=
[
scan_outputs
]
# 3. un-group / unzip the inputs
# Note ! We don't want to use the actual same variable as the ones
# used by the original scan, rather create clones of them
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
,
'_grad'
)
self_inputs
=
rval
[
0
]
self
.
outputs
,
'_grad'
)
self_inputs
=
rval
[
0
]
self_outputs
=
rval
[
1
]
seqs
=
self_inputs
[:
self
.
n_seqs
]
seqs
=
self_inputs
[:
self
.
n_seqs
]
offset
=
self
.
n_seqs
n_ins_mit_mot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
)
])
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
=
self
.
n_seqs
n_ins_mit_mot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
)])
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
+=
n_ins_mit_mot
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
...
...
@@ -1082,6 +1082,11 @@ class Scan(PureOp):
# 7.3. compute gradients of the inputs given one output
for
dx
,
out
in
enumerate
(
clean_outputs
):
inner_g_out
=
safe_new
(
out
)
###
#### I need to clip the gradient HERE !!
if
g_outs_no_shared
[
dx
]:
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
else
:
...
...
theano/scan_module/scan_opt.py
浏览文件 @
5a3a1d82
...
...
@@ -4,11 +4,11 @@ This module provides optimizations for scan
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron "
)
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
...
@@ -32,16 +32,20 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
# Logging function for sending warning or info
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan_opt'
)
list_opt_slice
=
[
tensor
.
opt
.
local_abs_merge
,
tensor
.
opt
.
local_mul_switch_sink
,
tensor
.
opt
.
local_upcast_elemwise_constant_inputs
,
tensor
.
opt
.
local_remove_switch_const_cond
,
tensor
.
opt
.
constant_folding
]
list_opt_slice
=
[
tensor
.
opt
.
local_abs_merge
,
tensor
.
opt
.
local_mul_switch_sink
,
tensor
.
opt
.
local_upcast_elemwise_constant_inputs
,
tensor
.
opt
.
local_remove_switch_const_cond
,
tensor
.
opt
.
constant_folding
]
def
warning
(
*
msg
):
_logger
.
warning
(
'WARNING theano.scan: '
+
' '
.
join
(
msg
))
_logger
.
warning
(
'WARNING theano.scan: '
+
' '
.
join
(
msg
))
def
info
(
*
msg
):
_logger
.
info
(
'INFO theano.scan: '
+
' '
.
join
(
msg
))
_logger
.
info
(
'INFO theano.scan: '
+
' '
.
join
(
msg
))
@gof.local_optimizer
([
None
])
def
remove_constants_and_unused_inputs_scan
(
node
):
...
...
@@ -58,9 +62,9 @@ def remove_constants_and_unused_inputs_scan(node):
return
False
op
=
node
.
op
# We only need to take care of sequences and other arguments
st
=
op
.
n_seqs
st
=
op
.
n_seqs
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
st
+=
op
.
n_sit_sot
st
+=
op
.
n_shared_outs
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
op
.
inputs
,
op
.
outputs
)
...
...
@@ -70,17 +74,17 @@ def remove_constants_and_unused_inputs_scan(node):
out_stuff_inner
=
op_ins
[
op
.
n_seqs
:
st
]
non_seqs
=
op_ins
[
st
:]
st
=
(
op
.
n_seqs
+
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
+
op
.
n_shared_outs
+
1
)
st
=
(
op
.
n_seqs
+
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
+
op
.
n_shared_outs
+
1
)
outer_non_seqs
=
node
.
inputs
[
st
:]
out_stuff_outer
=
node
.
inputs
[
1
+
op
.
n_seqs
:
st
]
out_stuff_outer
=
node
.
inputs
[
1
+
op
.
n_seqs
:
st
]
# To replace constants in the outer graph by clones in the inner graph
givens
=
{}
givens
=
{}
# All the inputs of the inner graph of the new scan
nw_inner
=
[]
# Same for the outer graph, initialized w/ number of steps
...
...
@@ -88,18 +92,18 @@ def remove_constants_and_unused_inputs_scan(node):
all_ins
=
gof
.
graph
.
inputs
(
op_outs
)
for
idx
in
xrange
(
op
.
n_seqs
):
if
(
isinstance
(
node
.
inputs
[
idx
+
1
],
tensor
.
TensorConstant
)
and
node
.
inputs
[
idx
+
1
]
.
tag
.
unique_value
is
not
None
):
if
(
isinstance
(
node
.
inputs
[
idx
+
1
],
tensor
.
TensorConstant
)
and
node
.
inputs
[
idx
+
1
]
.
tag
.
unique_value
is
not
None
):
try
:
# This works if input is a constant that has all entries
# equal
val
=
tensor
.
get_constant_value
(
node
.
inputs
[
idx
+
1
])
givens
[
op_ins
[
idx
]]
=
node
.
inputs
[
idx
+
1
]
.
clone
()[
0
]
val
=
tensor
.
get_constant_value
(
node
.
inputs
[
idx
+
1
])
givens
[
op_ins
[
idx
]]
=
node
.
inputs
[
idx
+
1
]
.
clone
()[
0
]
except
TypeError
:
pass
elif
op_ins
[
idx
]
in
all_ins
:
nw_inner
+=
[
op_ins
[
idx
]]
nw_outer
+=
[
node
.
inputs
[
idx
+
1
]]
nw_outer
+=
[
node
.
inputs
[
idx
+
1
]]
nw_n_seqs
=
len
(
nw_inner
)
# Add outputs stuff
...
...
@@ -114,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_outer
+=
[
nw_out
]
if
len
(
nw_inner
)
!=
len
(
op_ins
):
op_outs
=
scan_utils
.
clone
(
op_outs
,
replace
=
givens
)
op_outs
=
scan_utils
.
clone
(
op_outs
,
replace
=
givens
)
nw_info
=
op
.
info
.
copy
()
nw_info
[
'n_seqs'
]
=
nw_n_seqs
# DEBUG CHECK
...
...
@@ -128,11 +132,12 @@ scan_seqopt = theano.gof.SequenceDB()
optdb
.
register
(
'scan_seqopt'
,
scan_seqopt
,
1.9
,
'fast_run'
,
'scan'
)
scan_seqopt
.
register
(
'scanOp_remove_constants_and_unused_inputs'
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
5
,
'fast_run'
,
'scan'
)
# This is a global opt for historical reasons
# It should be possible to change it to a local opt.
class
PushOutNonSeqScan
(
gof
.
Optimizer
):
...
...
@@ -140,10 +145,9 @@ class PushOutNonSeqScan(gof.Optimizer):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
def
apply
(
self
,
env
):
nodelist
=
[
x
for
x
in
env
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
...
...
@@ -152,34 +156,31 @@ class PushOutNonSeqScan(gof.Optimizer):
def
process_node
(
self
,
env
,
node
):
# this flag tells if there was any change during the last iterations
changed
=
True
changed
=
True
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
local_env
=
gof
.
Env
(
clean_inputs
,
clean_outputs
)
max_iterations
=
2
*
len
(
local_env
.
toposort
())
+
3
max_iterations
=
2
*
len
(
local_env
.
toposort
())
+
3
counts
=
0
to_remove
=
[]
to_replace
=
[]
replace_with_in
=
[]
to_remove
=
[]
to_replace
=
[]
replace_with_in
=
[]
replace_with_out
=
[]
op
=
node
.
op
# Construct the list of non_sequences to simplify a few things
st
=
op
.
n_seqs
st
=
op
.
n_seqs
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
st
+=
op
.
n_sit_sot
st
+=
op
.
n_shared_outs
non_seqs
=
clean_inputs
[
st
:]
st
=
(
op
.
n_seqs
+
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
+
op
.
n_shared_outs
+
1
)
st
=
(
op
.
n_seqs
+
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
+
op
.
n_shared_outs
+
1
)
outer_non_seqs
=
node
.
inputs
[
st
:]
assert
len
(
non_seqs
)
==
len
(
outer_non_seqs
)
while
changed
and
counts
<
max_iterations
:
...
...
@@ -187,15 +188,15 @@ class PushOutNonSeqScan(gof.Optimizer):
changed
=
False
for
nd
in
local_env
.
toposort
():
if
(
numpy
.
all
([
(
x
in
non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
for
x
in
nd
.
inputs
])
and
if
(
numpy
.
all
([
(
x
in
non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
for
x
in
nd
.
inputs
])
and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)
and
# and we haven't already looked at this node
not
nd
in
to_remove
):
...
...
@@ -206,49 +207,50 @@ class PushOutNonSeqScan(gof.Optimizer):
outside_ins
=
[]
for
x
in
nd
.
inputs
:
if
x
in
non_seqs
:
outside_ins
+=
[
outer_non_seqs
[
non_seqs
.
index
(
x
)]]
outside_ins
+=
[
outer_non_seqs
[
non_seqs
.
index
(
x
)]]
elif
x
in
to_replace
:
outside_ins
+=
[
replace_with_out
[
to_replace
.
index
(
x
)]]
outside_ins
+=
[
replace_with_out
[
to_replace
.
index
(
x
)]]
elif
isinstance
(
x
,
theano
.
Constant
):
outside_ins
+=
[
x
.
clone
()]
outside_ins
+=
[
x
.
clone
()]
else
:
raise
Exception
(
(
'Error in the `scan_pushout_non_seq_operations`'
'. The optimization tries to move some '
'computation fron scan which is not allowed '
'to move. Report this on theano-users list'
),
x
)
(
'Error in the `scan_pushout_non_seq_'
'operations`. The optimization tries '
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'
),
x
)
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
# Step 2. Create variables for replacements
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
y_place_holder
=
scan_utils
.
safe_new
(
y
,
'_replace'
)
to_replace
+=
[
y
]
replace_with_in
+=
[
y_place_holder
]
y_place_holder
=
scan_utils
.
safe_new
(
y
,
'_replace'
)
to_replace
+=
[
y
]
replace_with_in
+=
[
y_place_holder
]
assert
type
(
y
)
==
type
(
nw_outer_node
.
outputs
[
idx
])
replace_with_out
+=
[
nw_outer_node
.
outputs
[
idx
]]
changed
=
True
if
counts
>=
max_iterations
:
raise
Exception
(
(
'Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!'
)
)
raise
Exception
(
'Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!'
)
# We need to check all candidate replacements and choose those that
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
clean_to_replace
=
[]
clean_replace_with_in
=
[]
clean_to_replace
=
[]
clean_replace_with_in
=
[]
clean_replace_with_out
=
[]
existent_nodes
=
[
nd
for
nd
in
local_env
.
toposort
()
existent_nodes
=
[
nd
for
nd
in
local_env
.
toposort
()
if
nd
not
in
to_remove
]
to_keep
=
[]
for
nd
in
existent_nodes
:
to_keep
+=
nd
.
inputs
for
idx
,
out
in
enumerate
(
to_replace
):
for
idx
,
out
in
enumerate
(
to_replace
):
if
out
in
to_keep
and
out
.
owner
not
in
existent_nodes
:
clean_to_replace
+=
[
out
]
clean_replace_with_in
+=
[
replace_with_in
[
idx
]]
clean_replace_with_in
+=
[
replace_with_in
[
idx
]]
clean_replace_with_out
+=
[
replace_with_out
[
idx
]]
if
len
(
clean_to_replace
)
>
0
:
...
...
@@ -256,7 +258,7 @@ class PushOutNonSeqScan(gof.Optimizer):
givens
=
{}
nw_outer
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
clean_replace_with_in
,
clean_replace_with_out
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
...
...
@@ -274,8 +276,24 @@ class PushOutNonSeqScan(gof.Optimizer):
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nw_node
=
nwScan
.
make_node
(
*
(
node
.
inputs
+
nw_outer
))
env
.
replace_all_validate
(
zip
(
node
.
outputs
,
nw_node
.
outputs
),
reason
=
'scan_push_computation_out'
)
reason
=
'scan_push_computation_out'
)
return
True
elif
to_keep
==
[]:
# Nothing in the inner graph should be kept
replace_with
=
{}
for
idx
,
out
in
enumerate
(
to_replace
):
if
out
in
local_env
.
outputs
:
x
=
node
.
outputs
[
local_env
.
outputs
.
index
(
out
)]
y
=
replace_with_out
[
idx
]
shape
=
[
y
.
shape
[
idx
]
for
idx
in
xrange
(
y
.
ndim
)]
replace_with
[
x
]
=
tensor
.
alloc
(
y
,
node
.
inputs
[
0
],
*
shape
)
# We need to add one extra dimension to the outputs
env
.
replace_all_validate
(
replace_with
.
items
(),
reason
=
'scan_push_computation_out'
)
else
:
return
False
...
...
@@ -290,17 +308,17 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops',
@gof.local_optimizer
([
None
])
def
scan_make_inplace
(
node
):
op
=
node
.
op
if
(
isinstance
(
op
,
scan_op
.
Scan
)
and
if
(
isinstance
(
op
,
scan_op
.
Scan
)
and
(
not
op
.
info
[
'inplace'
])
and
(
not
op
.
info
[
'gpu'
])):
info
=
op
.
info
.
copy
()
info
[
'inplace'
]
=
True
# inputs corresponding to sequences and n_steps
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls
=
op
.
outer_mitmot
(
node
)
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls
=
op
.
outer_mitmot
(
node
)
ls
+=
op
.
outer_mitsot
(
node
)
ls
+=
op
.
outer_sitsot
(
node
)
ls_end
=
op
.
outer_shared
(
node
)
ls_end
=
op
.
outer_shared
(
node
)
ls_end
+=
op
.
outer_nitsot
(
node
)
ls_end
+=
op
.
outer_non_seqs
(
node
)
n_outs
=
len
(
ls
)
...
...
@@ -309,19 +327,18 @@ def scan_make_inplace(node):
ls
[
idx
]
=
deep_copy_op
(
ls
[
idx
])
inputs
=
ls_begin
+
ls
+
ls_end
new_op
=
scan_op
.
Scan
(
op
.
inputs
,
op
.
outputs
,
info
)
new_op
=
scan_op
.
Scan
(
op
.
inputs
,
op
.
outputs
,
info
)
return
new_op
.
make_node
(
*
inputs
)
.
outputs
return
False
optdb
.
register
(
'scanOp_make_inplace'
,
opt
.
in2out
(
scan_make_inplace
,
ignore_newtrees
=
True
)
,
75
,
'fast_run'
,
'inplace'
,
'scan'
)
optdb
.
register
(
'scanOp_make_inplace'
,
opt
.
in2out
(
scan_make_inplace
,
ignore_newtrees
=
True
),
75
,
'fast_run'
,
'inplace'
,
'scan'
)
class
ScanSaveMem
(
gof
.
Optimizer
):
...
...
@@ -329,24 +346,25 @@ class ScanSaveMem(gof.Optimizer):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
def
process_node
(
self
,
env
,
node
):
# helpful functions
def
select_min
(
x
,
y
):
def
select_min
(
x
,
y
):
if
x
is
None
:
return
y
if
y
is
None
:
return
x
return
tensor
.
minimum
(
x
,
y
)
def
select_max
(
x
,
y
):
return
tensor
.
minimum
(
x
,
y
)
def
select_max
(
x
,
y
):
if
x
is
None
:
return
y
if
y
is
None
:
return
x
return
tensor
.
maximum
(
x
,
y
)
return
tensor
.
maximum
(
x
,
y
)
def
sanitize
(
x
):
if
x
is
None
:
...
...
@@ -367,9 +385,9 @@ class ScanSaveMem(gof.Optimizer):
op
=
node
.
op
c_outs
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
init_l
=
[
0
for
x
in
xrange
(
op
.
n_mit_mot
)]
init_l
+=
[
abs
(
numpy
.
min
(
v
))
for
v
in
op
.
tap_array
[
op
.
n_mit_mot
:]
]
init_l
+=
[
0
for
x
in
xrange
(
op
.
n_nit_sot
)]
init_l
=
[
0
for
x
in
xrange
(
op
.
n_mit_mot
)]
init_l
+=
[
abs
(
numpy
.
min
(
v
))
for
v
in
op
.
tap_array
[
op
.
n_mit_mot
:]
]
init_l
+=
[
0
for
x
in
xrange
(
op
.
n_nit_sot
)]
# 2. Check the clients of each output and see for how many steps
# does scan need to run
...
...
@@ -392,13 +410,13 @@ class ScanSaveMem(gof.Optimizer):
# change the number of steps in that case. To do this we set
# global_nsteps to None which is seen as a flag that nothing needs
# to be done
if
len
(
node
.
outputs
)
<=
c_outs
:
global_nsteps
=
{
'real'
:
-
1
,
'sym'
:
[]}
if
len
(
node
.
outputs
)
<=
c_outs
:
global_nsteps
=
{
'real'
:
-
1
,
'sym'
:
[]}
else
:
global_nsteps
=
None
# Keeps track of the original slices that each client represent
slices
=
[
None
for
o
in
node
.
outputs
]
slices
=
[
None
for
o
in
node
.
outputs
]
# A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate
...
...
@@ -409,31 +427,31 @@ class ScanSaveMem(gof.Optimizer):
# Note that for mit_mot outputs and shared outputs we can not change
# the number of intermediate steps stored without affecting the
# result of the op
store_steps
=
[
0
for
o
in
xrange
(
op
.
n_mit_mot
)]
store_steps
=
[
0
for
o
in
xrange
(
op
.
n_mit_mot
)]
store_steps
+=
[
-
1
for
o
in
node
.
outputs
[
op
.
n_mit_mot
:
c_outs
]]
# Flag that says if an input has changed and we need to do something
# or not
flag_store
=
False
# 2.2 Loop over the clients
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
slices
[
i
]
=
[]
for
cl
,
_
in
out
.
clients
:
for
cl
,
_
in
out
.
clients
:
# 2.1 outputs of the function
#=> output needs all its intermediate values
if
type
(
cl
)
==
str
:
# if the node is actually an output, then
# we need to store the entire thing
global_nsteps
=
None
slices
[
i
]
=
None
global_nsteps
=
None
slices
[
i
]
=
None
break
# 2.2 non-subtensor nodes
#=> output needs all its intermediate values
elif
not
isinstance
(
cl
.
op
,
tensor
.
basic
.
Subtensor
):
global_nsteps
=
None
slices
[
i
]
=
None
global_nsteps
=
None
slices
[
i
]
=
None
break
# 2.3 subtensor nodes
#=> output might need to store just a subset of its values
...
...
@@ -444,13 +462,11 @@ class ScanSaveMem(gof.Optimizer):
if
this_slice
==
None
:
# if unable to extract idx_list
#=> outputs needs all its intermediate values
global_nsteps
=
None
slices
[
i
]
=
None
global_nsteps
=
None
slices
[
i
]
=
None
break
# 2.3.2 extract the begin/end of the first dimension
if
i
>
op
.
n_mit_mot
:
try
:
length
=
shape_of
[
out
][
0
]
...
...
@@ -463,26 +479,27 @@ class ScanSaveMem(gof.Optimizer):
length
=
out
.
shape
[
0
]
cf_slice
=
tensor
.
basic
.
get_canonical_form_slice
(
this_slice
[
0
],
length
)
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
stop
is
None
):
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
stop
is
None
):
global_nsteps
=
None
break
if
isinstance
(
cf_slice
[
0
],
slice
):
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
stop
)
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
stop
)
else
:
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
])
+
1
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
])
+
1
if
stop
==
sys
.
maxint
or
stop
==
length
:
stop
=
None
else
:
# there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output as
# well. Which means that if you have an initial state of
# length 3, and you look for 5 steps you get an output y
# of length 8. If you only use y[:5], this does not mean
# that you only need to loop for 5 steps but actually
# only for 2 steps ( the first 3 are the initial state)
# array that contains the initial state of the output
# as well. Which means that if you have an initial state of
# length 3, and you look for 5 steps you get an output
# y of length 8. If you only use y[:5], this does not
# mean that you only need to loop for 5 steps but
# actually only for 2 steps ( the first 3 are the
# initial state)
stop
=
stop
-
init_l
[
i
]
# 2.3.3 we might get away with less number of steps
...
...
@@ -494,10 +511,11 @@ class ScanSaveMem(gof.Optimizer):
elif
(
type
(
stop
)
is
int
and
stop
==
sys
.
maxint
):
global_nsteps
=
None
# yes if it is a int k, 0 < k < maxint
elif
(
type
(
stop
)
is
int
and
global_nsteps
[
'real'
]
<
stop
):
elif
(
type
(
stop
)
is
int
and
global_nsteps
[
'real'
]
<
stop
):
global_nsteps
[
'real'
]
=
stop
# yes if it is a int k, 0 < k < maxint
elif
(
type
(
stop
)
is
int
and
stop
>
0
):
elif
(
type
(
stop
)
is
int
and
stop
>
0
):
pass
# not otherwise
else
:
...
...
@@ -510,10 +528,10 @@ class ScanSaveMem(gof.Optimizer):
# there are some symbolic tensors that limit the number of
# steps
if
len
(
global_nsteps
[
'sym'
])
==
0
:
if
len
(
global_nsteps
[
'sym'
])
==
0
:
sym_steps
=
None
else
:
sym_steps
=
global_nsteps
[
'sym'
][
0
]
sym_steps
=
global_nsteps
[
'sym'
][
0
]
for
c
in
global_nsteps
[
'sym'
][
1
:]:
sym_steps
=
tensor
.
maximum
(
sym_steps
,
c
)
...
...
@@ -527,12 +545,11 @@ class ScanSaveMem(gof.Optimizer):
nw_steps
=
node
.
inputs
[
0
]
global_nsteps
=
None
# 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
for
cl
,
_
in
out
.
clients
:
for
cl
,
_
in
out
.
clients
:
if
type
(
cl
)
==
str
:
store_steps
[
i
]
=
0
break
...
...
@@ -546,7 +563,7 @@ class ScanSaveMem(gof.Optimizer):
store_steps
[
i
]
=
0
break
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
start
is
None
):
store_steps
[
i
]
=
0
break
...
...
@@ -559,46 +576,48 @@ class ScanSaveMem(gof.Optimizer):
except
Exception
:
length
=
out
.
shape
[
0
]
cf_slice
=
tensor
.
basic
.
get_canonical_form_slice
(
this_slice
[
0
],
length
)
this_slice
[
0
],
length
)
if
isinstance
(
cf_slice
[
0
],
slice
):
start
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
start
)
start
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
start
)
else
:
start
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
])
if
start
==
0
or
store_steps
[
i
]
==
0
:
store_steps
[
i
]
=
0
else
:
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
if
store_steps
[
i
]
!=
-
1
:
pval
=
select_max
(
pval
,
store_steps
[
i
])
store_steps
[
i
]
=
pval
flag_store
=
True
orphane_outs
=
[
i
for
i
,
x
in
enumerate
(
store_steps
)
if
(
type
(
x
)
is
int
)
and
(
x
<
0
)
]
flag_store
=
flag_store
or
(
len
(
orphane_outs
)
>
0
)
orphane_outs
=
[
i
for
i
,
x
in
enumerate
(
store_steps
)
if
(
type
(
x
)
is
int
)
and
(
x
<
0
)
]
flag_store
=
flag_store
or
(
len
(
orphane_outs
)
>
0
)
# 3. is there anything to change ?
if
(
flag_store
or
global_nsteps
is
not
None
):
# 3.1 initialize inputs for the new scan
old_outputs
=
[]
nw_inputs
=
list
(
node
.
inputs
)
old_outputs
=
[]
nw_inputs
=
list
(
node
.
inputs
)
nw_inputs
[
0
]
=
nw_steps
# 3.2 check orphane outputs to see if we can eliminate any
required
,
not_required
=
\
scan_utils
.
scan_can_remove_outs
(
node
.
op
,
orphane_outs
)
required
,
not_required
=
\
scan_utils
.
scan_can_remove_outs
(
node
.
op
,
orphane_outs
)
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or are orphane and required
# by the inner function .. )
replaced_outs
=
[]
offset
=
1
+
op
.
n_seqs
+
op
.
n_mit_mot
for
idx
,
_val
in
enumerate
(
store_steps
[
op
.
n_mit_mot
:]):
for
idx
,
_val
in
enumerate
(
store_steps
[
op
.
n_mit_mot
:]):
i
=
idx
+
op
.
n_mit_mot
if
not
(
type
(
_val
)
is
int
and
_val
<=
0
and
i
not
in
required
):
if
not
(
type
(
_val
)
is
int
and
_val
<=
0
and
i
not
in
required
):
if
idx
+
op
.
n_mit_mot
in
required
:
if
idx
+
op
.
n_mit_mot
in
required
:
val
=
1
else
:
val
=
_val
...
...
@@ -610,21 +629,21 @@ class ScanSaveMem(gof.Optimizer):
# a) the input is a set_subtensor, in that case we
# can replace the initial tensor by a slice,
# b) it is not, and we simply take a slice of it.
if
(
nw_inputs
[
offset
+
idx
]
.
owner
and
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
,
if
(
nw_inputs
[
offset
+
idx
]
.
owner
and
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
,
tensor
.
IncSubtensor
)):
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tensor
.
as_tensor_variable
(
val
-
init_l
[
i
]))
tmp
=
pre_constant_merge
([
tmp
])[
0
]
nw_input
=
scan_utils
.
expand
(
_nw_input
,
tmp
)
nw_input
=
scan_utils
.
expand
(
_nw_input
,
tmp
)
else
:
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tensor
.
as_tensor_variable
(
val
))
tmp
=
pre_constant_merge
([
tmp
])[
0
]
nw_input
=
nw_inputs
[
offset
+
idx
][:
tmp
]
nw_input
=
nw_inputs
[
offset
+
idx
][:
tmp
]
nw_inputs
[
offset
+
idx
]
=
nw_input
nw_inputs
[
offset
+
idx
]
=
nw_input
replaced_outs
.
append
(
op
.
n_mit_mot
+
idx
)
odx
=
op
.
n_mit_mot
+
idx
old_outputs
+=
[(
odx
,
[
x
[
0
]
.
outputs
[
0
]
for
x
in
...
...
@@ -632,8 +651,8 @@ class ScanSaveMem(gof.Optimizer):
# If there is no memory pre-allocated for this output
elif
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
:
pos
=
(
op
.
n_mit_mot
+
idx
+
op
.
n_seqs
+
1
+
op
.
n_shared_outs
)
pos
=
(
op
.
n_mit_mot
+
idx
+
op
.
n_seqs
+
1
+
op
.
n_shared_outs
)
if
nw_inputs
[
pos
]
==
node
.
inputs
[
0
]:
nw_inputs
[
pos
]
=
val
odx
=
op
.
n_mit_mot
+
idx
...
...
@@ -646,43 +665,41 @@ class ScanSaveMem(gof.Optimizer):
for
idx
,
val
in
enumerate
(
store_steps
[
op
.
n_mit_mot
:]):
if
val
==
0
:
if
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
:
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
odx
=
op
.
n_mit_mot
+
idx
nw_input
=
scan_utils
.
expand
(
_nw_input
,
nw_steps
)
nw_inputs
[
offset
+
idx
]
=
nw_input
nw_inputs
[
offset
+
idx
]
=
nw_input
elif
idx
<
(
op
.
n_mit_sot
+
op
.
n_sit_sot
+
+
op
.
n_nit_sot
):
in_idx
=
offset
+
idx
+
op
.
n_shared_outs
op
.
n_nit_sot
):
in_idx
=
offset
+
idx
+
op
.
n_shared_outs
if
nw_inputs
[
in_idx
]
==
node
.
inputs
[
0
]:
nw_inputs
[
in_idx
]
=
nw_steps
nw_inputs
[
in_idx
]
=
nw_steps
odx
=
op
.
n_mit_mot
+
idx
# 3.5 Remove unwanted orphane outputs
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
\
scan_utils
.
compress_outs
(
op
,
not_required
,
nw_inputs
)
inv_compress_map
=
{}
for
k
,
v
in
compress_map
.
items
():
for
k
,
v
in
compress_map
.
items
():
inv_compress_map
[
v
]
=
k
node_ins
=
[
pre_greedy_local_optimizer
(
list_opt_slice
,
x
)
for
x
in
node_ins
=
[
pre_greedy_local_optimizer
(
list_opt_slice
,
x
)
for
x
in
node_ins
]
node_ins
=
pre_constant_merge
(
node_ins
)
# 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
info
[
'_scan_merge_visited'
]
=
True
new_outs
=
scan_op
.
Scan
(
inps
,
outs
,
info
)
.
make_node
(
*
node_ins
)
.
outputs
new_outs
=
scan_op
.
Scan
(
inps
,
outs
,
info
)
.
make_node
(
*
node_ins
)
.
outputs
old_new
=
[]
# 3.7 Get replace pairs for those outputs that do not change
# the number of intermediate steps stored
for
idx
,
sl
in
enumerate
(
slices
):
for
idx
,
sl
in
enumerate
(
slices
):
if
global_nsteps
and
sl
is
not
None
and
store_steps
[
idx
]
==
0
:
for
hdx
,
cl
in
enumerate
(
node
.
outputs
[
idx
]
.
clients
):
for
hdx
,
cl
in
enumerate
(
node
.
outputs
[
idx
]
.
clients
):
cnf_slice
,
old_slices
=
sl
[
hdx
]
# Sanitize the nw_slice by converting ints back into
# constants :) I only need to do this for the first
...
...
@@ -697,18 +714,16 @@ class ScanSaveMem(gof.Optimizer):
else
:
fslice
=
sanitize
(
cnf_slice
[
0
])
nw_slice
=
(
fslice
,)
+
tuple
(
old_slices
[
1
:])
nw_pos
=
inv_compress_map
[
idx
]
nw_out
=
new_outs
[
nw_pos
]
subtens
=
tensor
.
basic
.
Subtensor
(
nw_slice
)
# slice inputs
sl_ins
=
tensor
.
basic
.
Subtensor
.
collapse
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
*
sl_ins
)
.
outputs
[
0
]
if
new_o
.
ndim
>
0
:
...
...
@@ -721,34 +736,35 @@ class ScanSaveMem(gof.Optimizer):
if
len
(
old_outs
)
>
0
:
nw_pos
=
compress_map
[
pos
]
nw_out
=
new_outs
[
nw_pos
]
for
k
,
old
in
enumerate
(
old_outs
):
for
k
,
old
in
enumerate
(
old_outs
):
# Get the correct slice
cnf_slice
,
old_slices
=
slices
[
pos
][
k
]
if
type
(
cnf_slice
[
0
])
is
slice
:
start
=
(
cnf_slice
[
0
]
.
start
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
]
)
if
(
cnf_slice
[
0
]
.
stop
is
not
None
and
cnf_slice
[
0
]
.
stop
!=
sys
.
maxint
):
stop
=
(
cnf_slice
[
0
]
.
stop
-
nw_steps
-
start
=
(
cnf_slice
[
0
]
.
start
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
])
if
(
cnf_slice
[
0
]
.
stop
is
not
None
and
cnf_slice
[
0
]
.
stop
!=
sys
.
maxint
):
stop
=
(
cnf_slice
[
0
]
.
stop
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
])
else
:
stop
=
None
nw_slice
=
(
(
slice
(
sanitize
(
start
),
sanitize
(
stop
),
sanitize
(
cnf_slice
[
0
]
.
step
)),)
+
tuple
(
old_slices
[
1
:])
)
nw_slice
=
((
slice
(
sanitize
(
start
),
sanitize
(
stop
),
sanitize
(
cnf_slice
[
0
]
.
step
)),)
+
tuple
(
old_slices
[
1
:])
)
else
:
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
]
)
init_l
[
pos
]
+
store_steps
[
pos
]
)
nw_slice
=
(
sanitize
(
position
),)
+
tuple
(
old_slices
[
1
:])
nw_slice
=
(
sanitize
(
position
),)
+
\
tuple
(
old_slices
[
1
:])
subtens
=
tensor
.
basic
.
Subtensor
(
nw_slice
)
sl_ins
=
tensor
.
basic
.
Subtensor
.
collapse
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
*
sl_ins
)
.
outputs
[
0
]
if
new_o
.
ndim
>
0
:
...
...
@@ -757,13 +773,12 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes
if
flag_store
or
global_nsteps
is
not
None
:
for
idx
,
o
in
enumerate
(
node
.
outputs
):
for
idx
,
o
in
enumerate
(
node
.
outputs
):
if
not
(
idx
in
replaced_outs
)
and
not
idx
in
not_required
:
nw_pos
=
compress_map
[
idx
]
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
env
.
replace_all_validate
(
old_new
,
reason
=
'scan_save_mem'
)
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
env
.
replace_all_validate
(
old_new
,
reason
=
'scan_save_mem'
)
def
apply
(
self
,
env
):
...
...
@@ -776,16 +791,16 @@ class ScanSaveMem(gof.Optimizer):
# Just before specialize to have the other optimization
# like constant folding being applied
# This doesn't introduce inplace operations.
scan_seqopt
.
register
(
'scanOp_save_mem'
,
ScanSaveMem
(),
4
,
'fast_run'
,
'scan'
)
scan_seqopt
.
register
(
'scanOp_save_mem'
,
ScanSaveMem
(),
4
,
'fast_run'
,
'scan'
)
class
ScanMerge
(
gof
.
Optimizer
):
""" Graph Optimizer that merges different scan ops """
def add_requirements(self, env):
    """Extend `env` with a ``ReplaceValidate`` feature.

    :param env: the env this optimizer is going to work on; a fresh
        ``gof.toolbox.ReplaceValidate`` instance is handed to its
        ``extend`` method.
    """
    # NOTE(review): presumably this is what makes
    # ``env.replace_all_validate`` available to ``apply`` -- confirm
    # against gof.toolbox.
    replace_validate = gof.toolbox.ReplaceValidate()
    env.extend(replace_validate)
def
merge
(
self
,
nodes
):
...
...
@@ -796,29 +811,26 @@ class ScanMerge(gof.Optimizer):
else
:
as_while
=
False
info
=
{}
info
[
'tap_array'
]
=
[]
info
[
'n_seqs'
]
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
])
info
[
'n_mit_mot'
]
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
])
info
[
'n_mit_mot_outs'
]
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
])
info
=
{}
info
[
'tap_array'
]
=
[]
info
[
'n_seqs'
]
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
])
info
[
'n_mit_mot'
]
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
])
info
[
'n_mit_mot_outs'
]
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
])
info
[
'mit_mot_out_slices'
]
=
[]
info
[
'n_mit_sot'
]
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
])
info
[
'n_sit_sot'
]
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
])
info
[
'n_shared_outs'
]
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
])
info
[
'n_nit_sot'
]
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
])
info
[
'truncate_gradient'
]
=
nodes
[
0
]
.
op
.
truncate_gradient
info
[
'name'
]
=
'&'
.
join
([
nd
.
op
.
name
for
nd
in
nodes
])
info
[
'mode'
]
=
nodes
[
0
]
.
op
.
mode
info
[
'inplace'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
as_while
info
[
'profile'
]
=
nodes
[
0
]
.
op
.
profile
inner_ins
=
[]
outer_ins
=
[]
info
[
'n_mit_sot'
]
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
])
info
[
'n_sit_sot'
]
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
])
info
[
'n_shared_outs'
]
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
])
info
[
'n_nit_sot'
]
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
])
info
[
'truncate_gradient'
]
=
nodes
[
0
]
.
op
.
truncate_gradient
info
[
'name'
]
=
'&'
.
join
([
nd
.
op
.
name
for
nd
in
nodes
])
info
[
'mode'
]
=
nodes
[
0
]
.
op
.
mode
info
[
'inplace'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
as_while
info
[
'profile'
]
=
nodes
[
0
]
.
op
.
profile
inner_ins
=
[]
outer_ins
=
[]
inner_outs
=
[]
outer_outs
=
[]
...
...
@@ -828,57 +840,56 @@ class ScanMerge(gof.Optimizer):
k
.
name
+=
str
(
suffix
)
return
ls
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Seq
inner_ins
+=
rename
(
nd
.
op
.
inner_seqs
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_seqs
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
),
idx
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# MitMot
inner_ins
+=
rename
(
nd
.
op
.
inner_mitmot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_mitmot
(),
idx
)
inner_outs
+=
nd
.
op
.
inner_mitmot_outs
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitmot_taps
()
info
[
'mit_mot_out_slices'
]
+=
nd
.
op
.
mitmot_out_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# MitSot
inner_ins
+=
rename
(
nd
.
op
.
inner_mitsot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_mitsot
(),
idx
)
inner_outs
+=
nd
.
op
.
inner_mitsot_outs
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitsot_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# SitSot
inner_ins
+=
rename
(
nd
.
op
.
inner_sitsot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_sitsot
(),
idx
)
info
[
'tap_array'
]
+=
[[
-
1
]
for
x
in
xrange
(
nd
.
op
.
n_sit_sot
)]
inner_outs
+=
nd
.
op
.
inner_sitsot_outs
()
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
inner_ins
+=
rename
(
nd
.
op
.
inner_shared
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_shared
(
nd
),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_shared
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_shared
(
nd
),
idx
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# NitSot
inner_outs
+=
nd
.
op
.
inner_nitsot_outs
()
outer_ins
+=
rename
(
nd
.
op
.
outer_nitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_nitsot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_nitsot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
outer_outs
+=
nd
.
op
.
outer_shared_outs
(
nd
)
inner_outs
+=
nd
.
op
.
inner_shared_outs
()
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Non Seqs
inner_ins
+=
rename
(
nd
.
op
.
inner_non_seqs
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_non_seqs
(
nd
),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_non_seqs
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_non_seqs
(
nd
),
idx
)
# Add back the number of steps
outer_ins
=
[
nodes
[
0
]
.
inputs
[
0
]]
+
outer_ins
...
...
@@ -897,8 +908,6 @@ class ScanMerge(gof.Optimizer):
return
zip
(
outer_outs
,
new_outs
)
def
belongs_to_set
(
self
,
node
,
set_nodes
):
"""
This function checks if node `node` belongs to `set_nodes`, in the
...
...
@@ -918,7 +927,6 @@ class ScanMerge(gof.Optimizer):
except
TypeError
:
pass
rep_nsteps
=
rep
.
inputs
[
0
]
try
:
rep_nsteps
=
int
(
get_constant_value
(
rep_nsteps
))
...
...
@@ -943,11 +951,9 @@ class ScanMerge(gof.Optimizer):
rep
.
op
.
inputs
)
return
same_cond
and
(
nsteps
==
rep_nsteps
)
and
can_add
def
apply
(
self
,
env
):
# Collect all scan nodes ordered according to toposort
scan_nodes
=
[
nd
for
nd
in
env
.
toposort
()
scan_nodes
=
[
nd
for
nd
in
env
.
toposort
()
if
isinstance
(
nd
.
op
,
scan_op
.
Scan
)]
# All sets of possibly mergeable nodes
...
...
@@ -955,7 +961,7 @@ class ScanMerge(gof.Optimizer):
for
nd
in
scan_nodes
:
belongs_to_set_idx
=
-
1
for
pos
,
subset
in
enumerate
(
all_sets
):
for
pos
,
subset
in
enumerate
(
all_sets
):
if
self
.
belongs_to_set
(
nd
,
subset
):
assert
belongs_to_set_idx
==
-
1
belongs_to_set_idx
=
pos
...
...
@@ -968,7 +974,7 @@ class ScanMerge(gof.Optimizer):
for
subset
in
all_sets
:
if
len
(
subset
)
>
1
:
proposal
=
self
.
merge
(
subset
)
env
.
replace_all_validate
(
proposal
,
reason
=
'scan_merge'
)
env
.
replace_all_validate
(
proposal
,
reason
=
'scan_merge'
)
# after const merge but before stabilize so that we can have identity
...
...
@@ -980,23 +986,27 @@ scan_seqopt.register('scanOp_merge',
'fast_run'
,
'scan'
)
def has_duplicates(l):
    """Tell whether some entry of `l` occurs more than once.

    The entries are gathered into a ``set``; when that set ends up with
    fewer members than `l`, at least two entries collapsed into one, i.e.
    they were duplicates.  Entries have to be hashable for this to work.

    :param l: sized collection of hashable entries
    :return: True if `l` holds a duplicate, False otherwise
    """
    distinct = set(l)
    return len(l) > len(distinct)
def make_equiv(lo, li):
    """Build the equivalences between inner inputs that follow from the
    equivalence of their corresponding outer inputs.

    `lo` and `li` are walked in lockstep.  The first time an outer input
    is met, its inner input is remembered.  Every later occurrence of the
    same outer input yields one equivalence: the inner input of that later
    occurrence is appended to `left` and the remembered inner input of the
    first occurrence is appended to `right`.

    :param lo: outer inputs (they are used as dictionary keys, so they
        have to be hashable)
    :param li: inner inputs, ``li[k]`` being the one fed by ``lo[k]``
    :return: a pair of lists ``(left, right)`` of equal length, such that
        ``left[k]`` and ``right[k]`` are inner inputs fed by the same
        outer input
    """
    first_inner = {}
    left = []
    right = []
    for o, i in zip(lo, li):
        if o in first_inner:
            left.append(i)
            # Pair with the *inner* input of the first occurrence.  The
            # pairs are compared against inner-graph variables, so the
            # outer variable `o` itself could never match anything there.
            right.append(first_inner[o])
        else:
            first_inner[o] = i
    return left, right
@gof.local_optimizer
([
None
])
def
scan_merge_inouts
(
node
):
if
not
isinstance
(
node
.
op
,
scan_op
.
Scan
):
...
...
@@ -1056,58 +1066,68 @@ def scan_merge_inouts(node):
na
=
a
# start again
left
=
[]
left
=
[]
right
=
[]
if
has_duplicates
(
na
.
outer_in_shared
):
_left
,
_right
=
make_equiv
(
na
.
outer_in_shared
,
na
.
inner_in_shared
)
left
+=
_left
left
+=
_left
right
+=
_right
if
has_duplicates
(
na
.
outer_in_sit_sot
):
_left
,
_right
=
make_equiv
(
na
.
outer_in_sit_sot
,
na
.
inner_in_sit_sot
)
left
+=
_left
left
+=
_left
right
+=
_right
if
has_duplicates
(
na
.
outer_in_mit_mot
):
seen
=
{}
for
omm
,
imm
,
_sl
in
zip
(
na
.
outer_in_mit_mot
,
na
.
inner_in_mit_mot
,
na
.
mit_mot_in_slices
):
for
omm
,
imm
,
_sl
in
zip
(
na
.
outer_in_mit_mot
,
na
.
inner_in_mit_mot
,
na
.
mit_mot_in_slices
):
sl
=
tuple
(
_sl
)
if
(
omm
,
sl
)
in
seen
:
simm
=
seen
[(
omm
,
sl
)]
left
+=
imm
left
+=
imm
right
+=
simm
else
:
seen
[(
omm
,
sl
)]
=
imm
if
has_duplicates
(
na
.
outer_in_mit_sot
):
seen
=
{}
for
oms
,
ims
,
_sl
in
zip
(
na
.
outer_in_mit_sot
,
na
.
inner_in_mit_sot
,
na
.
mit_sot_in_slices
):
for
oms
,
ims
,
_sl
in
zip
(
na
.
outer_in_mit_sot
,
na
.
inner_in_mit_sot
,
na
.
mit_sot_in_slices
):
sl
=
tuple
(
_sl
)
if
(
oms
,
sl
)
in
seen
:
sims
=
seen
[(
oms
,
sl
)]
left
+=
ims
left
+=
ims
right
+=
sims
else
:
seen
[(
oms
,
sl
)]
=
ims
def
map_out
(
i
,
o
,
seen
):
for
si
,
so
in
seen
:
if
equal_computations
([
i
],
[
si
],
left
,
right
):
if
equal_computations
([
i
],
[
si
],
left
,
right
):
return
so
seen
.
append
((
i
,
o
))
return
o
seen
=
[]
na
.
outer_out_nit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
)]
na
.
outer_out_nit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
)]
seen
=
[]
na
.
outer_out_sit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_sit_sot
,
na
.
outer_out_sit_sot
)]
na
.
outer_out_sit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_sit_sot
,
na
.
outer_out_sit_sot
)]
seen
=
[]
na
.
outer_out_mit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_mit_sot
,
na
.
outer_out_mit_sot
)]
na
.
outer_out_mit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_mit_sot
,
na
.
outer_out_mit_sot
)]
seen
=
[]
new_outer_out_mit_mot
=
[]
for
imm
,
omm
,
osl
in
zip
(
na
.
inner_out_mit_mot
,
na
.
outer_out_mit_mot
,
na
.
mit_mot_out_slices
):
for
imm
,
omm
,
osl
in
zip
(
na
.
inner_out_mit_mot
,
na
.
outer_out_mit_mot
,
na
.
mit_mot_out_slices
):
for
simm
,
somm
,
sosl
in
seen
:
if
osl
==
sosl
and
equal_computations
(
imm
,
simm
,
left
,
right
):
new_outer_out_mit_mot
.
append
(
somm
)
...
...
@@ -1120,7 +1140,7 @@ def scan_merge_inouts(node):
return
na
.
outer_outputs
scan_seqopt
.
register
(
'scanOp_merge_inouts'
,
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
3
,
'fast_run'
,
'scan'
)
theano/scan_module/tests/test_scan.py
浏览文件 @
5a3a1d82
...
...
@@ -2260,6 +2260,40 @@ class T_Scan(unittest.TestCase):
assert
numpy
.
allclose
(
vnh0
,
tnh0
,
atol
=
1e-6
)
assert
numpy
.
allclose
(
vnW
,
tnW
,
atol
=
1e-6
)
def test_pushout_all(self):
    """Scan whose step function only reads its non-sequences.

    The graph is compiled with ``mode_with_opt`` and two things are
    checked: no ``Scan`` node is left in the compiled function, and the
    function returns the single-step result repeated 5 times.
    """
    W1 = tensor.matrix('W1')
    W2 = tensor.matrix('W2')
    h0 = tensor.vector('h0')

    def step(h, W1, W2):
        return tensor.dot(h, W1 + W2)

    o, _ = theano.scan(step,
                       non_sequences=[h0, W1, W2],
                       n_steps=5)

    f = theano.function([h0, W1, W2], o, mode=mode_with_opt)

    # After optimization the compiled graph must not hold any Scan node.
    scan_cls = theano.scan_module.scan_op.Scan
    leftover_scans = []
    for apply_node in f.maker.env.toposort():
        if isinstance(apply_node.op, scan_cls):
            leftover_scans.append(apply_node)
    assert not leftover_scans

    rng = numpy.random.RandomState(utt.fetch_seed())
    floatX = theano.config.floatX

    def draw(shape):
        return numpy.array(rng.uniform(size=shape), dtype=floatX)

    # Draw order matters for reproducibility: h first, then W1, then W2.
    v_h = draw((2,))
    v_W1 = draw((2, 2))
    v_W2 = draw((2, 2))

    # What we ask theano for is the 2-element vector dot(h, W1 + W2)
    # repeated 5 times, so the reference is built with that same (5, 2)
    # shape and every row is filled with the single-step result.
    expected = numpy.zeros((5, 2))
    expected[:, :] = numpy.dot(v_h, v_W1 + v_W2)

    assert numpy.allclose(expected, f(v_h, v_W1, v_W2))
def
test_pushout
(
self
):
W1
=
tensor
.
matrix
(
'W1'
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论