Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5a3a1d82
提交
5a3a1d82
authored
11月 22, 2011
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #219 from pascanur/better_pushout_optimization
Better pushout optimization
上级
518fd20d
0fe3b745
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
585 行增加
和
526 行删除
+585
-526
scan_op.py
theano/scan_module/scan_op.py
+267
-262
scan_opt.py
theano/scan_module/scan_opt.py
+284
-264
test_scan.py
theano/scan_module/tests/test_scan.py
+34
-0
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
5a3a1d82
...
@@ -5,10 +5,10 @@ See scan.py for details on scan
...
@@ -5,10 +5,10 @@ See scan.py for details on scan
"""
"""
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"Frederic Bastien "
"James Bergstra "
"James Bergstra "
"Pascal Lamblin "
)
"Pascal Lamblin "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
@@ -39,15 +39,11 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
...
@@ -39,15 +39,11 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
class
Scan
(
PureOp
):
class
Scan
(
PureOp
):
#
def
__init__
(
self
,
# OLD DOCUMENTATION CAN BE FOUND NEAR REVISION 2581
inputs
,
#
outputs
,
info
,
def
__init__
(
self
typeConstructor
=
None
,
,
inputs
,
outputs
,
info
,
typeConstructor
=
None
):
):
"""
"""
:param inputs: inputs of the inner function of scan
:param inputs: inputs of the inner function of scan
...
@@ -56,7 +52,7 @@ class Scan(PureOp):
...
@@ -56,7 +52,7 @@ class Scan(PureOp):
the scan op.
the scan op.
"""
"""
# adding properties into self
# adding properties into self
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
outputs
=
outputs
self
.
__dict__
.
update
(
info
)
self
.
__dict__
.
update
(
info
)
# I keep a version of info in self, to use in __eq__ and __hash__,
# I keep a version of info in self, to use in __eq__ and __hash__,
...
@@ -70,15 +66,16 @@ class Scan(PureOp):
...
@@ -70,15 +66,16 @@ class Scan(PureOp):
jdx
=
0
jdx
=
0
if
typeConstructor
is
None
:
if
typeConstructor
is
None
:
typeConstructor
=
lambda
broadcastable
,
dtype
:
TensorType
(
typeConstructor
=
lambda
broadcastable
,
dtype
:
TensorType
(
broadcastable
=
broadcastable
,
dtype
=
dtype
)
broadcastable
=
broadcastable
,
dtype
=
dtype
)
while
idx
<
self
.
n_mit_mot_outs
:
while
idx
<
self
.
n_mit_mot_outs
:
# Not that for mit_mot there are several output slices per
# Not that for mit_mot there are several output slices per
# output sequence
# output sequence
o
=
outputs
[
idx
]
o
=
outputs
[
idx
]
self
.
output_types
.
append
(
self
.
output_types
.
append
(
typeConstructor
(
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
typeConstructor
(
,
dtype
=
o
.
type
.
dtype
)
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
,
dtype
=
o
.
type
.
dtype
)
)
)
idx
+=
len
(
self
.
mit_mot_out_slices
[
jdx
])
idx
+=
len
(
self
.
mit_mot_out_slices
[
jdx
])
jdx
+=
1
jdx
+=
1
...
@@ -88,32 +85,32 @@ class Scan(PureOp):
...
@@ -88,32 +85,32 @@ class Scan(PureOp):
for
o
in
outputs
[
idx
:
end
]:
for
o
in
outputs
[
idx
:
end
]:
self
.
output_types
.
append
(
self
.
output_types
.
append
(
typeConstructor
(
typeConstructor
(
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
broadcastable
=
(
False
,)
+
o
.
type
.
broadcastable
,
,
dtype
=
o
.
type
.
dtype
))
dtype
=
o
.
type
.
dtype
))
# shared outputs + possibly the ending condition
# shared outputs + possibly the ending condition
for
o
in
outputs
[
end
:]:
for
o
in
outputs
[
end
:]:
self
.
output_types
.
append
(
o
.
type
)
self
.
output_types
.
append
(
o
.
type
)
if
self
.
as_while
:
if
self
.
as_while
:
self
.
output_types
=
self
.
output_types
[:
-
1
]
self
.
output_types
=
self
.
output_types
[:
-
1
]
self
.
destroy_map
=
{}
self
.
destroy_map
=
{}
if
hasattr
(
self
,
'inplace'
)
and
self
.
inplace
:
if
hasattr
(
self
,
'inplace'
)
and
self
.
inplace
:
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
):
self
.
n_sit_sot
):
self
.
destroy_map
[
idx
]
=
[
idx
+
1
+
self
.
n_seqs
]
self
.
destroy_map
[
idx
]
=
[
idx
+
1
+
self
.
n_seqs
]
mode_instance
=
compile
.
mode
.
get_mode
(
self
.
mode
)
mode_instance
=
compile
.
mode
.
get_mode
(
self
.
mode
)
# if the default mode is used, and that mode is ProfileMode
# if the default mode is used, and that mode is ProfileMode
# then we need to copy the mode otherwise the time for a given
# then we need to copy the mode otherwise the time for a given
# op will be counted multiple times
# op will be counted multiple times
if
(
self
.
mode
is
None
and
if
(
self
.
mode
is
None
and
isinstance
(
mode_instance
,
compile
.
profilemode
.
ProfileMode
)
):
isinstance
(
mode_instance
,
compile
.
profilemode
.
ProfileMode
)):
mode_instance
=
compile
.
profilemode
.
ProfileMode
(
mode_instance
=
compile
.
profilemode
.
ProfileMode
(
optimizer
=
mode_instance
.
provided_optimizer
optimizer
=
mode_instance
.
provided_optimizer
,
,
linker
=
mode_instance
.
provided_linker
)
linker
=
mode_instance
.
provided_linker
)
compile
.
profilemode
.
prof_mode_instance_to_print
.
append
(
mode_instance
)
compile
.
profilemode
.
prof_mode_instance_to_print
.
append
(
mode_instance
)
self
.
mode_instance
=
mode_instance
self
.
mode_instance
=
mode_instance
if
self
.
name
:
if
self
.
name
:
self
.
mode_instance
.
message
=
self
.
name
+
" sub profile"
self
.
mode_instance
.
message
=
self
.
name
+
" sub profile"
...
@@ -122,7 +119,7 @@ class Scan(PureOp):
...
@@ -122,7 +119,7 @@ class Scan(PureOp):
else
:
else
:
self
.
mode_instance
=
mode_instance
self
.
mode_instance
=
mode_instance
if
not
hasattr
(
self
,
'name'
)
or
self
.
name
is
None
:
if
not
hasattr
(
self
,
'name'
)
or
self
.
name
is
None
:
self
.
name
=
'scan_fn'
self
.
name
=
'scan_fn'
# to have a fair __eq__ comparison later on, we update the info with
# to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the
# the actual mode used to compile the function and the name of the
...
@@ -130,27 +127,26 @@ class Scan(PureOp):
...
@@ -130,27 +127,26 @@ class Scan(PureOp):
self
.
info
[
'name'
]
=
self
.
name
self
.
info
[
'name'
]
=
self
.
name
# Pre-computing some values to speed up perform
# Pre-computing some values to speed up perform
self
.
mintaps
=
[
numpy
.
min
(
x
)
for
x
in
self
.
tap_array
]
self
.
mintaps
=
[
numpy
.
min
(
x
)
for
x
in
self
.
tap_array
]
self
.
mintaps
+=
[
0
for
x
in
xrange
(
self
.
n_nit_sot
)
]
self
.
mintaps
+=
[
0
for
x
in
xrange
(
self
.
n_nit_sot
)
]
self
.
seqs_arg_offset
=
1
+
self
.
n_seqs
self
.
seqs_arg_offset
=
1
+
self
.
n_seqs
self
.
shared_arg_offset
=
(
self
.
seqs_arg_offset
self
.
shared_arg_offset
=
(
self
.
seqs_arg_offset
+
+
self
.
n_mit_mot
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
self
.
nit_sot_arg_offset
=
(
self
.
shared_arg_offset
+
self
.
nit_sot_arg_offset
=
(
self
.
shared_arg_offset
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
self
.
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
self
.
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
self
.
n_tap_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
self
.
n_tap_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
if
not
self
.
info
[
'gpu'
]:
if
not
self
.
info
[
'gpu'
]:
tmp_in
,
tmp_out
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
tmp_in
,
tmp_out
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
)
self
.
outputs
)
local_env
=
gof
.
Env
(
tmp_in
,
tmp_out
)
local_env
=
gof
.
Env
(
tmp_in
,
tmp_out
)
self
.
_cmodule_key
=
gof
.
CLinker
.
cmodule_key_
(
local_env
,[])
self
.
_cmodule_key
=
gof
.
CLinker
.
cmodule_key_
(
local_env
,
[])
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
else
:
else
:
self
.
_hash_inner_graph
=
self
.
info
[
'gpu_hash'
]
self
.
_hash_inner_graph
=
self
.
info
[
'gpu_hash'
]
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
assert
numpy
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
assert
numpy
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
# assert dtype is consistent
# assert dtype is consistent
...
@@ -173,23 +169,23 @@ class Scan(PureOp):
...
@@ -173,23 +169,23 @@ class Scan(PureOp):
# Flags that indicate which inputs are vectors
# Flags that indicate which inputs are vectors
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
inputs
[
1
:
1
+
self
.
n_seqs
]
]
inputs
[
1
:
1
+
self
.
n_seqs
]
]
self
.
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
self
.
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
self
.
n_outs
)]
]
self
.
n_outs
)]]
self
.
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
self
.
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
# Check if input sequences and variables representing a slice of
# Check if input sequences and variables representing a slice of
# them have the same dtype
# them have the same dtype
for
idx
in
xrange
(
self
.
n_seqs
):
for
idx
in
xrange
(
self
.
n_seqs
):
if
inputs
[
1
+
idx
]
.
dtype
!=
self
.
inputs
[
idx
]
.
dtype
:
if
inputs
[
1
+
idx
]
.
dtype
!=
self
.
inputs
[
idx
]
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'sequence'
raise
ValueError
(
err_msg1
%
(
'sequence'
,
,
str
(
inputs
[
1
+
idx
])
str
(
inputs
[
1
+
idx
]),
,
idx
idx
,
,
inputs
[
1
+
idx
]
.
dtype
inputs
[
1
+
idx
]
.
dtype
,
,
str
(
self
.
inputs
[
idx
])
str
(
self
.
inputs
[
idx
]),
,
self
.
inputs
[
idx
]
.
dtype
)
)
self
.
inputs
[
idx
]
.
dtype
)
)
# Check that this 3 things have the same dtype for mit_mot:
# Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output
# - initial state of the output
...
@@ -198,73 +194,73 @@ class Scan(PureOp):
...
@@ -198,73 +194,73 @@ class Scan(PureOp):
# Maybe checking that ndim fits would be good as well !?
# Maybe checking that ndim fits would be good as well !?
index_i
=
self
.
n_seqs
index_i
=
self
.
n_seqs
index_o
=
0
index_o
=
0
index
=
1
+
self
.
n_seqs
index
=
1
+
self
.
n_seqs
start
=
index
start
=
index
end
=
index
+
self
.
n_mit_mot
end
=
index
+
self
.
n_mit_mot
while
index
<
end
:
while
index
<
end
:
for
k
in
self
.
tap_array
[
index
-
start
]:
for
k
in
self
.
tap_array
[
index
-
start
]:
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
' in scan nomenclature) '
,
,
str
(
inputs
[
index
])
str
(
inputs
[
index
]),
,
index
index
,
,
inputs
[
index
]
.
dtype
inputs
[
index
]
.
dtype
,
,
str
(
self
.
inputs
[
index_i
])
str
(
self
.
inputs
[
index_i
]),
,
self
.
inputs
[
index_i
]
.
dtype
)
)
self
.
inputs
[
index_i
]
.
dtype
)
)
index_i
+=
1
index_i
+=
1
for
k
in
self
.
mit_mot_out_slices
[
index
-
start
]:
for
k
in
self
.
mit_mot_out_slices
[
index
-
start
]:
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
])
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
,
index
index
,
,
inputs
[
index
]
.
dtype
inputs
[
index
]
.
dtype
,
,
self
.
outputs
[
index_o
]
.
dtype
)
)
self
.
outputs
[
index_o
]
.
dtype
)
)
index_o
+=
1
index_o
+=
1
index
+=
1
index
+=
1
# Same checks as above but for outputs of type mit_sot and sit_sot
# Same checks as above but for outputs of type mit_sot and sit_sot
end
+=
self
.
n_mit_sot
+
self
.
n_sit_sot
end
+=
self
.
n_mit_sot
+
self
.
n_sit_sot
while
index
<
end
:
while
index
<
end
:
for
k
in
self
.
tap_array
[
index
-
start
]:
for
k
in
self
.
tap_array
[
index
-
start
]:
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'Initial state'
raise
ValueError
(
err_msg1
%
(
'Initial state'
,
,
str
(
inputs
[
index
])
str
(
inputs
[
index
]),
,
index
index
,
,
inputs
[
index
]
.
dtype
inputs
[
index
]
.
dtype
,
,
str
(
self
.
inputs
[
index_i
])
str
(
self
.
inputs
[
index_i
]),
,
self
.
inputs
[
index_i
]
.
dtype
)
)
self
.
inputs
[
index_i
]
.
dtype
)
)
index_i
+=
1
index_i
+=
1
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
])
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
,
index
index
,
,
inputs
[
index
]
.
dtype
inputs
[
index
]
.
dtype
,
,
self
.
outputs
[
index_o
]
.
dtype
)
)
self
.
outputs
[
index_o
]
.
dtype
)
)
index_o
+=
1
index_o
+=
1
index
+=
1
index
+=
1
# Check that the shared variable and their update rule have the same
# Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?!
# dtype. Maybe even same type ?!
end
+=
self
.
n_shared_outs
end
+=
self
.
n_shared_outs
index_o
+=
self
.
n_nit_sot
index_o
+=
self
.
n_nit_sot
while
index
<
end
:
while
index
<
end
:
if
(
hasattr
(
inputs
[
index
],
'dtype'
)
and
if
(
hasattr
(
inputs
[
index
],
'dtype'
)
and
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
):
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
):
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
])
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
,
index
index
,
,
inputs
[
index
]
.
dtype
inputs
[
index
]
.
dtype
,
,
self
.
outputs
[
index_o
]
.
dtype
)
)
self
.
outputs
[
index_o
]
.
dtype
)
)
index
+=
1
index
+=
1
index_o
+=
1
index_o
+=
1
for
x
in
inputs
[
index
:
index
+
self
.
n_nit_sot
]:
for
x
in
inputs
[
index
:
index
+
self
.
n_nit_sot
]:
# For every nit_sot input we get as input a int/uint that
# For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
# used by truncated BPTT and by scan space optimization
if
(
str
(
x
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
or
if
(
str
(
x
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
or
x
.
ndim
!=
0
):
x
.
ndim
!=
0
):
raise
ValueError
(
'For output
%
d you need to provide a '
raise
ValueError
(
'For output
%
d you need to provide a '
'scalar int !'
,
x
)
'scalar int !'
,
x
)
apply_node
=
Apply
(
self
apply_node
=
Apply
(
self
,
,
inputs
inputs
,
,
[
t
()
for
t
in
self
.
output_types
])
[
t
()
for
t
in
self
.
output_types
])
return
apply_node
return
apply_node
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -284,7 +280,7 @@ class Scan(PureOp):
...
@@ -284,7 +280,7 @@ class Scan(PureOp):
# check. Namely, do the internal graph represent same
# check. Namely, do the internal graph represent same
# computations
# computations
for
self_in
,
other_in
in
zip
(
self
.
inputs
,
other
.
inputs
):
for
self_in
,
other_in
in
zip
(
self
.
inputs
,
other
.
inputs
):
if
self_in
.
type
!=
other_in
.
type
:
if
self_in
.
type
!=
other_in
.
type
:
return
False
return
False
if
not
scan_utils
.
equal_computations
(
self
.
outputs
,
if
not
scan_utils
.
equal_computations
(
self
.
outputs
,
...
@@ -308,21 +304,19 @@ class Scan(PureOp):
...
@@ -308,21 +304,19 @@ class Scan(PureOp):
else
:
else
:
name
=
'for'
name
=
'for'
if
self
.
inplace
:
if
self
.
inplace
:
aux_txt
=
'
%
s{inplace,
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
aux_txt
=
'
%
s{inplace,
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
else
:
else
:
aux_txt
=
'
%
s{
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
aux_txt
=
'
%
s{
%
s,
%
s}'
%
(
name
,
gpu_str
,
str
(
self
.
name
))
return
aux_txt
return
aux_txt
def
__hash__
(
self
):
def
__hash__
(
self
):
return
(
hash
(
type
(
self
))
^
return
(
hash
(
type
(
self
))
^
# and a hash representing the inner graph using the
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
# CLinker.cmodule_key_
self
.
_hash_inner_graph
^
self
.
_hash_inner_graph
^
scan_utils
.
hash_listsDictsTuples
(
self
.
info
)
)
scan_utils
.
hash_listsDictsTuples
(
self
.
info
))
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
"""
"""
...
@@ -348,7 +342,6 @@ class Scan(PureOp):
...
@@ -348,7 +342,6 @@ class Scan(PureOp):
# Setting up all my variables in what I believe is a more Cython
# Setting up all my variables in what I believe is a more Cython
# friendly form
# friendly form
node_input_storage
=
[
storage_map
[
r
]
for
r
in
node
.
inputs
]
node_input_storage
=
[
storage_map
[
r
]
for
r
in
node
.
inputs
]
node_output_storage
=
[
storage_map
[
r
]
for
r
in
node
.
outputs
]
node_output_storage
=
[
storage_map
[
r
]
for
r
in
node
.
outputs
]
node_input_compute
=
[
compute_map
[
r
]
for
r
in
node
.
inputs
]
node_input_compute
=
[
compute_map
[
r
]
for
r
in
node
.
inputs
]
...
@@ -357,64 +350,65 @@ class Scan(PureOp):
...
@@ -357,64 +350,65 @@ class Scan(PureOp):
# If a shared variable is the result of a ViewOp it is a clear
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# indication that we need to copy that value after the perform of
# scan is done
# scan is done
slices
=
(
self
.
n_mit_mot_outs
+
slices
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
wrapped_inputs
=
[
Param
(
x
,
borrow
=
True
)
for
x
in
self
.
inputs
]
wrapped_inputs
=
[
Param
(
x
,
borrow
=
True
)
for
x
in
self
.
inputs
]
wrapped_outputs
=
[
Out
(
x
,
borrow
=
True
)
for
x
in
wrapped_outputs
=
[
Out
(
x
,
borrow
=
True
)
for
x
in
self
.
outputs
[:
slices
]
]
self
.
outputs
[:
slices
]]
wrapped_outputs
+=
self
.
outputs
[
slices
:]
wrapped_outputs
+=
self
.
outputs
[
slices
:]
profile
=
None
profile
=
None
if
(
theano
.
config
.
profile
or
(
isinstance
(
self
.
profile
,
(
basestring
,
bool
,
int
))
if
(
theano
.
config
.
profile
or
(
isinstance
(
self
.
profile
,
(
basestring
,
bool
,
int
))
and
self
.
profile
)):
and
self
.
profile
)):
if
isinstance
(
self
.
profile
,
basestring
):
if
isinstance
(
self
.
profile
,
basestring
):
profile
=
ScanProfileStats
(
name
=
self
.
profile
)
profile
=
ScanProfileStats
(
name
=
self
.
profile
)
else
:
else
:
profile
=
ScanProfileStats
(
name
=
self
.
name
)
profile
=
ScanProfileStats
(
name
=
self
.
name
)
elif
self
.
profile
:
elif
self
.
profile
:
profile
=
self
.
profile
profile
=
self
.
profile
self
.
fn
=
function
(
wrapped_inputs
,
self
.
fn
=
function
(
wrapped_inputs
,
wrapped_outputs
,
wrapped_outputs
,
mode
=
self
.
mode_instance
,
mode
=
self
.
mode_instance
,
name
=
self
.
name
,
name
=
self
.
name
,
profile
=
profile
)
profile
=
profile
)
try
:
try
:
cython_mintaps
=
numpy
.
asarray
(
self
.
mintaps
,
dtype
=
'int32'
)
raise
ImportError
cython_mintaps
=
numpy
.
asarray
(
self
.
mintaps
,
dtype
=
'int32'
)
cython_tap_array_len
=
\
cython_tap_array_len
=
\
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
tap_array
],
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
tap_array
],
dtype
=
'int32'
)
dtype
=
'int32'
)
if
len
(
self
.
tap_array
)
==
0
:
if
len
(
self
.
tap_array
)
==
0
:
d1
=
0
d1
=
0
else
:
else
:
d1
=
numpy
.
max
(
cython_tap_array_len
)
d1
=
numpy
.
max
(
cython_tap_array_len
)
d0
=
len
(
self
.
tap_array
)
d0
=
len
(
self
.
tap_array
)
cython_tap_array
=
numpy
.
zeros
((
d0
,
d1
),
dtype
=
'int32'
)
cython_tap_array
=
numpy
.
zeros
((
d0
,
d1
),
dtype
=
'int32'
)
for
_d0
in
range
(
d0
):
for
_d0
in
range
(
d0
):
for
_d1
in
range
(
cython_tap_array_len
[
_d0
]):
for
_d1
in
range
(
cython_tap_array_len
[
_d0
]):
cython_tap_array
[
_d0
,
_d1
]
=
self
.
tap_array
[
_d0
][
_d1
]
cython_tap_array
[
_d0
,
_d1
]
=
self
.
tap_array
[
_d0
][
_d1
]
cython_mit_mot_out_nslices
=
\
cython_mit_mot_out_nslices
=
\
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
],
numpy
.
asarray
([
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
],
dtype
=
'int32'
)
dtype
=
'int32'
)
if
len
(
self
.
mit_mot_out_slices
)
==
0
:
if
len
(
self
.
mit_mot_out_slices
)
==
0
:
d1
=
0
d1
=
0
else
:
else
:
d1
=
numpy
.
max
(
cython_mit_mot_out_nslices
)
d1
=
numpy
.
max
(
cython_mit_mot_out_nslices
)
d0
=
len
(
self
.
mit_mot_out_slices
)
d0
=
len
(
self
.
mit_mot_out_slices
)
cython_mit_mot_out_slices
=
numpy
.
zeros
((
d0
,
d1
),
cython_mit_mot_out_slices
=
numpy
.
zeros
((
d0
,
d1
),
dtype
=
'int32'
)
dtype
=
'int32'
)
for
_d0
in
range
(
d0
):
for
_d0
in
range
(
d0
):
for
_d1
in
range
(
cython_mit_mot_out_nslices
[
_d0
]):
for
_d1
in
range
(
cython_mit_mot_out_nslices
[
_d0
]):
cython_mit_mot_out_slices
[
_d0
,
_d1
]
=
\
cython_mit_mot_out_slices
[
_d0
,
_d1
]
=
\
self
.
mit_mot_out_slices
[
_d0
][
_d1
]
self
.
mit_mot_out_slices
[
_d0
][
_d1
]
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
]
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
]
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
node
.
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
node
.
inputs
[
1
+
self
.
n_seqs
:
self
.
n_outs
)]
]
(
1
+
self
.
n_seqs
+
self
.
n_outs
)]
]
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
cython_vector_seqs
=
numpy
.
asarray
(
self
.
vector_seqs
,
cython_vector_seqs
=
numpy
.
asarray
(
self
.
vector_seqs
,
dtype
=
'int32'
)
dtype
=
'int32'
)
...
@@ -448,6 +442,7 @@ class Scan(PureOp):
...
@@ -448,6 +442,7 @@ class Scan(PureOp):
except
ImportError
:
except
ImportError
:
p
=
self
.
execute
p
=
self
.
execute
# default arguments are stored in the closure of `rval`
# default arguments are stored in the closure of `rval`
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
):
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
for
o
in
node
.
outputs
:
for
o
in
node
.
outputs
:
...
@@ -463,14 +458,14 @@ class Scan(PureOp):
...
@@ -463,14 +458,14 @@ class Scan(PureOp):
return
self
.
inputs
[:
self
.
n_seqs
]
return
self
.
inputs
[:
self
.
n_seqs
]
def
outer_seqs
(
self
,
node
):
def
outer_seqs
(
self
,
node
):
return
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
return
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
def
inner_mitmot
(
self
):
def
inner_mitmot
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
return
self
.
inputs
[
self
.
n_seqs
:
self
.
n_seqs
+
n_taps
]
return
self
.
inputs
[
self
.
n_seqs
:
self
.
n_seqs
+
n_taps
]
def
outer_mitmot
(
self
,
node
):
def
outer_mitmot
(
self
,
node
):
return
node
.
inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
return
node
.
inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
def
inner_mitmot_outs
(
self
):
def
inner_mitmot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
...
@@ -490,80 +485,80 @@ class Scan(PureOp):
...
@@ -490,80 +485,80 @@ class Scan(PureOp):
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
return
self
.
inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
return
self
.
inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
self
.
n_seqs
+
ntaps_upto_sit_sot
]
self
.
n_seqs
+
ntaps_upto_sit_sot
]
def
outer_mitsot
(
self
,
node
):
def
outer_mitsot
(
self
,
node
):
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
return
node
.
inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
def
inner_mitsot_outs
(
self
):
def
inner_mitsot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
self
.
outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
return
self
.
outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
def
outer_mitsot_outs
(
self
,
node
):
def
outer_mitsot_outs
(
self
,
node
):
return
node
.
outputs
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
return
node
.
outputs
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
mitsot_taps
(
self
):
def
mitsot_taps
(
self
):
return
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
return
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
inner_sitsot
(
self
):
def
inner_sitsot
(
self
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
return
self
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
self
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot
(
self
,
node
):
def
outer_sitsot
(
self
,
node
):
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
node
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
inner_sitsot_outs
(
self
):
def
inner_sitsot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
offset
=
self
.
n_mit_sot
+
n_taps
return
self
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
self
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot_outs
(
self
,
node
):
def
outer_sitsot_outs
(
self
,
node
):
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
return
node
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
node
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_nitsot
(
self
,
node
):
def
outer_nitsot
(
self
,
node
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
self
.
n_sit_sot
+
self
.
n_shared_outs
)
return
node
.
inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
node
.
inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_nitsot_outs
(
self
):
def
inner_nitsot_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
return
self
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
self
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
outer_nitsot_outs
(
self
,
node
):
def
outer_nitsot_outs
(
self
,
node
):
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
return
node
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
node
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_shared
(
self
):
def
inner_shared
(
self
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
return
self
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
self
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared
(
self
,
node
):
def
outer_shared
(
self
,
node
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
return
node
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
node
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_shared_outs
(
self
):
def
inner_shared_outs
(
self
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
+
self
.
n_nit_sot
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
+
self
.
n_nit_sot
return
self
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
self
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared_outs
(
self
,
node
):
def
outer_shared_outs
(
self
,
node
):
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
return
node
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
node
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_non_seqs
(
self
):
def
inner_non_seqs
(
self
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
...
@@ -574,12 +569,11 @@ class Scan(PureOp):
...
@@ -574,12 +569,11 @@ class Scan(PureOp):
return
self
.
inputs
[
offset
:]
return
self
.
inputs
[
offset
:]
def
outer_non_seqs
(
self
,
node
):
def
outer_non_seqs
(
self
,
node
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
return
node
.
inputs
[
offset
:]
return
node
.
inputs
[
offset
:]
def
execute
(
self
,
node
,
args
,
outs
):
def
execute
(
self
,
node
,
args
,
outs
):
"""
"""
The args are packed like this:
The args are packed like this:
...
@@ -607,7 +601,7 @@ class Scan(PureOp):
...
@@ -607,7 +601,7 @@ class Scan(PureOp):
# negative flip sequences around, and make n_steps positive
# negative flip sequences around, and make n_steps positive
t0_call
=
time
.
time
()
t0_call
=
time
.
time
()
t_fn
=
0
t_fn
=
0
n_steps
=
args
[
0
]
n_steps
=
args
[
0
]
seqs
=
[]
seqs
=
[]
if
n_steps
<
0
:
if
n_steps
<
0
:
n_steps
=
abs
(
n_steps
)
n_steps
=
abs
(
n_steps
)
...
@@ -616,7 +610,7 @@ class Scan(PureOp):
...
@@ -616,7 +610,7 @@ class Scan(PureOp):
raise
ValueError
((
'Sequence is shorter then the required '
raise
ValueError
((
'Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'number of steps : (n_steps, seq, '
'seq.shape):'
),
n_steps
,
'seq.shape):'
),
n_steps
,
node
.
inputs
[
1
+
idx
],
node
.
inputs
[
1
+
idx
],
seq
.
shape
)
seq
.
shape
)
seqs
.
append
(
seq
[::
-
1
])
seqs
.
append
(
seq
[::
-
1
])
else
:
else
:
...
@@ -625,35 +619,37 @@ class Scan(PureOp):
...
@@ -625,35 +619,37 @@ class Scan(PureOp):
raise
ValueError
((
'Sequence is shorter then the required '
raise
ValueError
((
'Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'number of steps : (n_steps, seq, '
'seq.shape):'
),
n_steps
,
'seq.shape):'
),
n_steps
,
node
.
inputs
[
1
+
idx
],
node
.
inputs
[
1
+
idx
],
seq
.
shape
)
seq
.
shape
)
seqs
.
append
(
seq
)
seqs
.
append
(
seq
)
# 2. Allocate memory for the outputs. Construct the list:
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containting the length of each output
# store_steps -- map containting the length of each output
# pos -- map containing the current position of each output
# pos -- map containing the current position of each
# output
store_steps
=
[
arg
.
shape
[
0
]
for
arg
store_steps
=
[
arg
.
shape
[
0
]
for
arg
in
args
[
self
.
seqs_arg_offset
:
in
args
[
self
.
seqs_arg_offset
:
self
.
shared_arg_offset
]
]
self
.
shared_arg_offset
]]
store_steps
+=
[
arg
for
arg
in
store_steps
+=
[
arg
for
arg
in
args
[
self
.
nit_sot_arg_offset
:
args
[
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
]
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
]
]
]
pos
=
[
(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
pos
=
[
(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
)]
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
)]
# 2.1 Create storage space for outputs
# 2.1 Create storage space for outputs
for
idx
in
xrange
(
self
.
n_outs
):
for
idx
in
xrange
(
self
.
n_outs
):
if
self
.
inplace
:
if
self
.
inplace
:
# ^ Case 1. Outputs should be computed inplace of their
# ^ Case 1. Outputs should be computed inplace of their
# initial state
# initial state
outs
[
idx
][
0
]
=
args
[
self
.
seqs_arg_offset
+
idx
]
outs
[
idx
][
0
]
=
args
[
self
.
seqs_arg_offset
+
idx
]
elif
(
outs
[
idx
][
0
]
is
not
None
and
elif
(
outs
[
idx
][
0
]
is
not
None
and
outs
[
idx
][
0
]
.
shape
[
1
:]
==
args
[
self
.
seqs_arg_offset
+
idx
]
.
shape
[
1
:]
outs
[
idx
][
0
]
.
shape
[
1
:]
==
args
[
self
.
seqs_arg_offset
+
and
outs
[
idx
][
0
]
.
shape
[
0
]
>=
store_steps
[
idx
]
):
idx
]
.
shape
[
1
:]
and
outs
[
idx
][
0
]
.
shape
[
0
]
>=
store_steps
[
idx
]):
# Put in the values of the initial state
# Put in the values of the initial state
outs
[
idx
][
0
]
=
outs
[
idx
][
0
][:
store_steps
[
idx
]]
outs
[
idx
][
0
]
=
outs
[
idx
][
0
][:
store_steps
[
idx
]]
if
idx
>
self
.
n_mit_mot
:
if
idx
>
self
.
n_mit_mot
:
l
=
-
self
.
mintaps
[
idx
]
l
=
-
self
.
mintaps
[
idx
]
outs
[
idx
][
0
][:
l
]
=
args
[
self
.
seqs_arg_offset
+
idx
][:
l
]
outs
[
idx
][
0
][:
l
]
=
args
[
self
.
seqs_arg_offset
+
idx
][:
l
]
...
@@ -662,28 +658,28 @@ class Scan(PureOp):
...
@@ -662,28 +658,28 @@ class Scan(PureOp):
else
:
else
:
outs
[
idx
][
0
]
=
args
[
self
.
seqs_arg_offset
+
idx
]
.
copy
()
outs
[
idx
][
0
]
=
args
[
self
.
seqs_arg_offset
+
idx
]
.
copy
()
offset
=
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
offset
=
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
other_args
=
args
[
offset
:]
other_args
=
args
[
offset
:]
input_storage
=
self
.
fn
.
input_storage
input_storage
=
self
.
fn
.
input_storage
output_storage
=
self
.
fn
.
output_storage
output_storage
=
self
.
fn
.
output_storage
fn
=
self
.
fn
.
fn
fn
=
self
.
fn
.
fn
offset
=
(
self
.
n_seqs
+
sum
(
map
(
len
,
self
.
tap_array
[:
self
.
n_outs
]))
+
offset
=
(
self
.
n_seqs
+
sum
(
map
(
len
,
self
.
tap_array
[:
self
.
n_outs
]))
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
for
idx
in
xrange
(
len
(
other_args
)):
for
idx
in
xrange
(
len
(
other_args
)):
input_storage
[
idx
+
offset
]
.
storage
[
0
]
=
other_args
[
idx
]
input_storage
[
idx
+
offset
]
.
storage
[
0
]
=
other_args
[
idx
]
i
=
0
i
=
0
cond
=
True
cond
=
True
############## THE MAIN LOOP #########################
############## THE MAIN LOOP #########################
#for i in xrange(n_steps):
#for i in xrange(n_steps):
while
(
i
<
n_steps
)
and
cond
:
while
(
i
<
n_steps
)
and
cond
:
# sequences over which scan iterates
# sequences over which scan iterates
# 3. collect input slices
# 3. collect input slices
for
idx
in
xrange
(
self
.
n_seqs
):
for
idx
in
xrange
(
self
.
n_seqs
):
if
self
.
vector_seqs
[
idx
]:
if
self
.
vector_seqs
[
idx
]:
input_storage
[
idx
]
.
storage
[
0
]
=
seqs
[
idx
][
i
:
i
+
1
]
.
reshape
(())
input_storage
[
idx
]
.
storage
[
0
]
=
\
seqs
[
idx
][
i
:
i
+
1
]
.
reshape
(())
else
:
else
:
input_storage
[
idx
]
.
storage
[
0
]
=
seqs
[
idx
][
i
]
input_storage
[
idx
]
.
storage
[
0
]
=
seqs
[
idx
][
i
]
...
@@ -691,26 +687,25 @@ class Scan(PureOp):
...
@@ -691,26 +687,25 @@ class Scan(PureOp):
for
idx
in
xrange
(
self
.
n_outs
):
for
idx
in
xrange
(
self
.
n_outs
):
if
self
.
vector_outs
[
idx
]:
if
self
.
vector_outs
[
idx
]:
for
tap
in
self
.
tap_array
[
idx
]:
for
tap
in
self
.
tap_array
[
idx
]:
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
input_storage
[
offset
]
.
storage
[
0
]
=
\
input_storage
[
offset
]
.
storage
[
0
]
=
\
outs
[
idx
][
0
][
_idx
:
_idx
+
1
]
.
reshape
(())
outs
[
idx
][
0
][
_idx
:
_idx
+
1
]
.
reshape
(())
offset
+=
1
offset
+=
1
else
:
else
:
for
tap
in
self
.
tap_array
[
idx
]:
for
tap
in
self
.
tap_array
[
idx
]:
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
input_storage
[
offset
]
.
storage
[
0
]
=
outs
[
idx
][
0
][
_idx
]
input_storage
[
offset
]
.
storage
[
0
]
=
outs
[
idx
][
0
][
_idx
]
offset
+=
1
offset
+=
1
a_offset
=
self
.
shared_arg_offset
a_offset
=
self
.
shared_arg_offset
o_offset
=
self
.
n_outs
+
self
.
n_nit_sot
o_offset
=
self
.
n_outs
+
self
.
n_nit_sot
if
i
==
0
:
if
i
==
0
:
for
j
in
xrange
(
self
.
n_shared_outs
):
for
j
in
xrange
(
self
.
n_shared_outs
):
input_storage
[
offset
]
.
storage
[
0
]
=
args
[
a_offset
+
j
]
input_storage
[
offset
]
.
storage
[
0
]
=
args
[
a_offset
+
j
]
offset
+=
1
offset
+=
1
else
:
else
:
for
j
in
xrange
(
self
.
n_shared_outs
):
for
j
in
xrange
(
self
.
n_shared_outs
):
input_storage
[
offset
]
.
storage
[
0
]
=
outs
[
o_offset
+
j
][
0
]
input_storage
[
offset
]
.
storage
[
0
]
=
outs
[
o_offset
+
j
][
0
]
offset
+=
1
offset
+=
1
# 4. collecting slices where the output should be stored
# 4. collecting slices where the output should be stored
...
@@ -718,23 +713,24 @@ class Scan(PureOp):
...
@@ -718,23 +713,24 @@ class Scan(PureOp):
output_storage
[
idx
]
.
storage
[
0
]
=
None
output_storage
[
idx
]
.
storage
[
0
]
=
None
offset
=
self
.
n_mit_mot_outs
offset
=
self
.
n_mit_mot_outs
if
i
!=
0
and
self
.
n_nit_sot
>
0
:
if
i
!=
0
and
self
.
n_nit_sot
>
0
:
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
-
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
):
self
.
n_mit_mot
):
if
(
store_steps
[
idx
+
self
.
n_mit_mot
]
==
1
or
if
(
store_steps
[
idx
+
self
.
n_mit_mot
]
==
1
or
self
.
vector_outs
[
idx
+
self
.
n_mit_mot
]):
self
.
vector_outs
[
idx
+
self
.
n_mit_mot
]):
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
else
:
else
:
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
\
_pos0
=
idx
+
self
.
n_mit_mot
outs
[
idx
+
self
.
n_mit_mot
][
0
][
pos
[
idx
+
self
.
n_mit_mot
]]
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
\
outs
[
_pos0
][
0
][
pos
[
_pos0
]]
else
:
else
:
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
-
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
):
self
.
n_mit_mot
):
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
offset
+=
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
offset
+=
self
.
n_outs
+
self
.
n_nit_sot
-
self
.
n_mit_mot
for
idx
in
xrange
(
self
.
n_shared_outs
):
for
idx
in
xrange
(
self
.
n_shared_outs
):
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
# If condition add it to the mix
# If condition add it to the mix
if
self
.
as_while
:
if
self
.
as_while
:
pdx
=
offset
+
self
.
n_shared_outs
pdx
=
offset
+
self
.
n_shared_outs
...
@@ -762,97 +758,102 @@ class Scan(PureOp):
...
@@ -762,97 +758,102 @@ class Scan(PureOp):
# 5.1 Copy over the values for mit_mot outputs
# 5.1 Copy over the values for mit_mot outputs
for
j
in
xrange
(
self
.
n_mit_mot
):
for
j
in
xrange
(
self
.
n_mit_mot
):
for
k
in
self
.
mit_mot_out_slices
[
j
]:
for
k
in
self
.
mit_mot_out_slices
[
j
]:
outs
[
j
][
0
][
k
+
pos
[
j
]]
=
output_storage
[
offset_out
]
.
storage
[
0
]
outs
[
j
][
0
][
k
+
pos
[
j
]]
=
\
output_storage
[
offset_out
]
.
storage
[
0
]
offset_out
+=
1
offset_out
+=
1
# 5.2 Copy over the values for mit_sot/sit_sot outputs
# 5.2 Copy over the values for mit_sot/sit_sot outputs
begin
=
self
.
n_mit_mot
begin
=
self
.
n_mit_mot
end
=
self
.
n_outs
end
=
self
.
n_outs
offset_out
-=
self
.
n_mit_mot
offset_out
-=
self
.
n_mit_mot
for
j
in
xrange
(
begin
,
end
):
for
j
in
xrange
(
begin
,
end
):
if
(
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]
or
if
(
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]
or
outs
[
j
][
0
][
pos
[
j
]]
is
not
output_storage
[
offset_out
+
j
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
is
not
output_storage
[
offset_out
+
j
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
=
output_storage
[
offset_out
+
j
]
.
storage
[
0
]
outs
[
j
][
0
][
pos
[
j
]]
=
\
output_storage
[
offset_out
+
j
]
.
storage
[
0
]
# 5.3 Copy over the values for nit_sot outputs
# 5.3 Copy over the values for nit_sot outputs
begin
=
end
begin
=
end
end
+=
self
.
n_nit_sot
end
+=
self
.
n_nit_sot
for
j
in
xrange
(
begin
,
end
):
for
j
in
xrange
(
begin
,
end
):
if
i
==
0
:
if
i
==
0
:
jout
=
j
+
offset_out
jout
=
j
+
offset_out
shape
=
(
store_steps
[
j
],)
+
output_storage
[
jout
]
.
storage
[
0
]
.
shape
shape
=
(
store_steps
[
j
],)
+
\
output_storage
[
jout
]
.
storage
[
0
]
.
shape
if
len
(
output_storage
[
jout
]
.
storage
[
0
]
.
shape
)
==
0
:
if
len
(
output_storage
[
jout
]
.
storage
[
0
]
.
shape
)
==
0
:
self
.
vector_outs
[
j
]
=
True
self
.
vector_outs
[
j
]
=
True
dtype
=
output_storage
[
jout
]
.
storage
[
0
]
.
dtype
dtype
=
output_storage
[
jout
]
.
storage
[
0
]
.
dtype
if
(
outs
[
j
][
0
]
is
None
or
if
(
outs
[
j
][
0
]
is
None
or
outs
[
j
][
0
]
.
shape
[
0
]
<
store_steps
[
j
]
or
outs
[
j
][
0
]
.
shape
[
0
]
<
store_steps
[
j
]
or
outs
[
j
][
0
]
.
shape
[
1
:]
!=
shape
[
1
:]
or
outs
[
j
][
0
]
.
shape
[
1
:]
!=
shape
[
1
:]
or
outs
[
j
][
0
]
.
dtype
!=
dtype
):
outs
[
j
][
0
]
.
dtype
!=
dtype
):
if
self
.
gpu
:
if
self
.
gpu
:
outs
[
j
][
0
]
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
(
shape
)
_cuda
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
outs
[
j
][
0
]
=
_cuda
.
zeros
(
shape
)
else
:
else
:
outs
[
j
][
0
]
=
numpy
.
zeros
(
shape
,
dtype
)
outs
[
j
][
0
]
=
numpy
.
zeros
(
shape
,
dtype
)
elif
outs
[
j
][
0
]
.
shape
[
0
]
!=
store_steps
[
j
]:
elif
outs
[
j
][
0
]
.
shape
[
0
]
!=
store_steps
[
j
]:
outs
[
j
][
0
]
=
outs
[
j
][
0
][:
store_steps
[
j
]]
outs
[
j
][
0
]
=
outs
[
j
][
0
][:
store_steps
[
j
]]
outs
[
j
][
0
][
pos
[
j
]]
=
output_storage
[
jout
]
.
storage
[
0
]
outs
[
j
][
0
][
pos
[
j
]]
=
output_storage
[
jout
]
.
storage
[
0
]
elif
(
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]
or
elif
(
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]
or
outs
[
j
][
0
][
pos
[
j
]]
is
not
output_storage
[
j
+
offset_out
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
is
not
outs
[
j
][
0
][
pos
[
j
]]
=
output_storage
[
j
+
offset_out
]
.
storage
[
0
]
output_storage
[
j
+
offset_out
]
.
storage
[
0
]):
outs
[
j
][
0
][
pos
[
j
]]
=
\
output_storage
[
j
+
offset_out
]
.
storage
[
0
]
# 5.4 Copy over the values for outputs corresponding to shared
# 5.4 Copy over the values for outputs corresponding to shared
# variables
# variables
begin
=
end
begin
=
end
end
+=
self
.
n_shared_outs
end
+=
self
.
n_shared_outs
for
j
in
xrange
(
begin
,
end
):
for
j
in
xrange
(
begin
,
end
):
jout
=
j
+
offset_out
jout
=
j
+
offset_out
outs
[
j
][
0
]
=
output_storage
[
jout
]
.
storage
[
0
]
outs
[
j
][
0
]
=
output_storage
[
jout
]
.
storage
[
0
]
pos
=
[
(
idx
+
1
)
%
store
for
idx
,
store
in
pos
=
[(
idx
+
1
)
%
store
for
idx
,
store
in
itertools
.
izip
(
pos
,
store_steps
)
itertools
.
izip
(
pos
,
store_steps
)]
]
i
=
i
+
1
i
=
i
+
1
# 6. Check if you need to re-order output buffers
# 6. Check if you need to re-order output buffers
begin
=
self
.
n_mit_mot
begin
=
self
.
n_mit_mot
end
=
self
.
n_outs
+
self
.
n_nit_sot
end
=
self
.
n_outs
+
self
.
n_nit_sot
for
idx
in
xrange
(
begin
,
end
):
for
idx
in
xrange
(
begin
,
end
):
min_tap
=
self
.
mintaps
[
idx
]
min_tap
=
self
.
mintaps
[
idx
]
if
(
store_steps
[
idx
]
<
i
-
self
.
mintaps
[
idx
]
and
if
(
store_steps
[
idx
]
<
i
-
self
.
mintaps
[
idx
]
and
pos
[
idx
]
<
store_steps
[
idx
]
):
pos
[
idx
]
<
store_steps
[
idx
]):
pdx
=
pos
[
idx
]
pdx
=
pos
[
idx
]
if
pdx
<
store_steps
[
idx
]
//
2
:
if
pdx
<
store_steps
[
idx
]
//
2
:
shape
=
(
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
shape
=
(
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
cuda
.
CudaNdarray
):
cuda
.
CudaNdarray
):
tmp
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
(
shape
)
_cuda
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
tmp
=
_cuda
.
zeros
(
shape
)
else
:
else
:
tmp
=
numpy
.
empty
(
shape
)
tmp
=
numpy
.
empty
(
shape
)
tmp
[:]
=
outs
[
idx
][
0
][:
pdx
]
tmp
[:]
=
outs
[
idx
][
0
][:
pdx
]
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
outs
[
idx
][
0
][
pdx
:]
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
outs
[
idx
][
0
][
pdx
:]
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
tmp
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
tmp
del
tmp
del
tmp
else
:
else
:
shape
=
(
store_steps
[
idx
]
-
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
shape
=
(
store_steps
[
idx
]
-
pdx
,)
+
outs
[
idx
][
0
]
.
shape
[
1
:]
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
if
cuda
.
cuda_available
and
isinstance
(
outs
[
idx
][
0
],
cuda
.
CudaNdarray
):
cuda
.
CudaNdarray
):
tmp
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
(
shape
)
_cuda
=
cuda
.
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
tmp
=
_cuda
.
zeros
(
shape
)
else
:
else
:
tmp
=
numpy
.
empty
(
shape
)
tmp
=
numpy
.
empty
(
shape
)
tmp
[:]
=
outs
[
idx
][
0
][
pdx
:]
tmp
[:]
=
outs
[
idx
][
0
][
pdx
:]
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
outs
[
idx
][
0
][:
pdx
]
outs
[
idx
][
0
][
store_steps
[
idx
]
-
pdx
:]
=
outs
[
idx
][
0
][:
pdx
]
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
tmp
outs
[
idx
][
0
][:
store_steps
[
idx
]
-
pdx
]
=
tmp
del
tmp
del
tmp
# This would normally happen only when doing truncated
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenarion Scan is
# backpropagation through time. In such a scenarion Scan is
# expected to return 0 for all entries for which the gradient is
# expected to return 0 for all entries for which the gradient is
# not actually computed
# not actually computed
elif
store_steps
[
idx
]
>
i
-
self
.
mintaps
[
idx
]:
elif
store_steps
[
idx
]
>
i
-
self
.
mintaps
[
idx
]:
outs
[
idx
][
0
][
i
-
self
.
mintaps
[
idx
]:]
=
0
outs
[
idx
][
0
][
i
-
self
.
mintaps
[
idx
]:]
=
0
# This is a fix for a bug introduced by while. If you say
# This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output
# you want to loop up to a condition, you expect the output
# to have that length ( and not the maximal length possible)
# to have that length ( and not the maximal length possible)
...
@@ -883,7 +884,7 @@ class Scan(PureOp):
...
@@ -883,7 +884,7 @@ class Scan(PureOp):
profile
.
callcount
+=
1
profile
.
callcount
+=
1
profile
.
nbsteps
+=
n_steps
profile
.
nbsteps
+=
n_steps
profile
.
call_time
+=
t_call
profile
.
call_time
+=
t_call
profile
.
vm_call_time
+=
t_fn
profile
.
vm_call_time
+=
t_fn
if
hasattr
(
self
.
fn
.
fn
,
'update_profile'
):
if
hasattr
(
self
.
fn
.
fn
,
'update_profile'
):
self
.
fn
.
fn
.
update_profile
(
profile
)
self
.
fn
.
fn
.
update_profile
(
profile
)
...
@@ -896,7 +897,7 @@ class Scan(PureOp):
...
@@ -896,7 +897,7 @@ class Scan(PureOp):
#self.fn.maker.mode.fn_time += t_fn
#self.fn.maker.mode.fn_time += t_fn
# Old Profile Mode */
# Old Profile Mode */
self
.
t_call
=
t_call
self
.
t_call
=
t_call
self
.
t_fn
=
t_fn
self
.
t_fn
=
t_fn
### Infer Shape
### Infer Shape
def
infer_shape
(
self
,
node
,
input_shapes
):
def
infer_shape
(
self
,
node
,
input_shapes
):
...
@@ -905,26 +906,27 @@ class Scan(PureOp):
...
@@ -905,26 +906,27 @@ class Scan(PureOp):
# is the shape of self.inputs[i]
# is the shape of self.inputs[i]
# sequences
# sequences
seqs_shape
=
[
x
[
1
:]
for
x
in
input_shapes
[
1
:
1
+
self
.
n_seqs
]
]
seqs_shape
=
[
x
[
1
:]
for
x
in
input_shapes
[
1
:
1
+
self
.
n_seqs
]
]
# mit_mot, mit_sot, sit_sot
# mit_mot, mit_sot, sit_sot
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
n_outs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
outs_shape
=
[]
outs_shape
=
[]
for
idx
in
xrange
(
n_outs
):
for
idx
in
xrange
(
n_outs
):
for
k
in
self
.
tap_array
[
idx
]:
for
k
in
self
.
tap_array
[
idx
]:
outs_shape
+=
[
input_shapes
[
idx
+
self
.
n_seqs
+
1
][
1
:]
]
outs_shape
+=
[
input_shapes
[
idx
+
self
.
n_seqs
+
1
][
1
:]
]
# shared_outs
# shared_outs
offset
=
1
+
self
.
n_seqs
+
n_outs
offset
=
1
+
self
.
n_seqs
+
n_outs
for
idx
in
xrange
(
self
.
n_shared_outs
):
for
idx
in
xrange
(
self
.
n_shared_outs
):
outs_shape
+=
[
input_shapes
[
idx
+
offset
]
]
outs_shape
+=
[
input_shapes
[
idx
+
offset
]
]
# non_sequences
# non_sequences
offset
+=
self
.
n_nit_sot
+
self
.
n_shared_outs
offset
+=
self
.
n_nit_sot
+
self
.
n_shared_outs
inner_ins_shapes
=
seqs_shape
+
outs_shape
+
input_shapes
[
offset
:]
inner_ins_shapes
=
seqs_shape
+
outs_shape
+
input_shapes
[
offset
:]
assert
len
(
inner_ins_shapes
)
==
len
(
self
.
inputs
)
assert
len
(
inner_ins_shapes
)
==
len
(
self
.
inputs
)
# Non-sequences have a direct equivalent from self.inputs in node.inputs
# Non-sequences have a direct equivalent from self.inputs in
# node.inputs
inner_non_sequences
=
self
.
inputs
[
len
(
seqs_shape
)
+
len
(
outs_shape
):]
inner_non_sequences
=
self
.
inputs
[
len
(
seqs_shape
)
+
len
(
outs_shape
):]
out_equivalent
=
{}
out_equivalent
=
{}
for
in_ns
,
out_ns
in
zip
(
inner_non_sequences
,
node
.
inputs
[
offset
:]):
for
in_ns
,
out_ns
in
zip
(
inner_non_sequences
,
node
.
inputs
[
offset
:]):
...
@@ -934,22 +936,22 @@ class Scan(PureOp):
...
@@ -934,22 +936,22 @@ class Scan(PureOp):
else
:
else
:
self_outs
=
self
.
outputs
self_outs
=
self
.
outputs
outs_shape
=
scan_utils
.
infer_shape
(
outs_shape
=
scan_utils
.
infer_shape
(
outs
=
self_outs
,
outs
=
self_outs
,
inputs
=
self
.
inputs
,
inputs
=
self
.
inputs
,
input_shapes
=
inner_ins_shapes
)
input_shapes
=
inner_ins_shapes
)
# Will be used to check if outs_shape can be expressed without using
# Will be used to check if outs_shape can be expressed without using
# variables in self.inputs.
# variables in self.inputs.
# The shapes of node.inputs are valid.
# The shapes of node.inputs are valid.
validator
=
scan_utils
.
Validator
(
validator
=
scan_utils
.
Validator
(
valid
=
input_shapes
,
valid
=
input_shapes
,
invalid
=
self
.
inputs
,
invalid
=
self
.
inputs
,
valid_equivalent
=
out_equivalent
)
valid_equivalent
=
out_equivalent
)
offset
=
1
+
self
.
n_seqs
offset
=
1
+
self
.
n_seqs
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
offset
+=
n_outs
offset
+=
n_outs
for
x
in
xrange
(
self
.
n_nit_sot
):
for
x
in
xrange
(
self
.
n_nit_sot
):
out_shape_x
=
outs_shape
[
n_outs
+
x
]
out_shape_x
=
outs_shape
[
n_outs
+
x
]
if
out_shape_x
is
None
:
if
out_shape_x
is
None
:
# This output is not a tensor, and has no shape
# This output is not a tensor, and has no shape
scan_outs
.
append
(
None
)
scan_outs
.
append
(
None
)
...
@@ -957,10 +959,10 @@ class Scan(PureOp):
...
@@ -957,10 +959,10 @@ class Scan(PureOp):
# We need to make sure that we can compute the shapes from
# We need to make sure that we can compute the shapes from
# node.inputs, and constants, without using the variables
# node.inputs, and constants, without using the variables
# in the inner function.
# in the inner function.
r
=
node
.
outputs
[
n_outs
+
x
]
r
=
node
.
outputs
[
n_outs
+
x
]
assert
r
.
ndim
==
1
+
len
(
out_shape_x
)
assert
r
.
ndim
==
1
+
len
(
out_shape_x
)
shp
=
[
node
.
inputs
[
offset
+
self
.
n_shared_outs
+
x
]]
shp
=
[
node
.
inputs
[
offset
+
self
.
n_shared_outs
+
x
]]
for
i
,
shp_i
in
zip
(
xrange
(
1
,
r
.
ndim
),
out_shape_x
):
for
i
,
shp_i
in
zip
(
xrange
(
1
,
r
.
ndim
),
out_shape_x
):
# Validate shp_i. v_shape_i is either None (if invalid),
# Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates
# or a (variable, Boolean) tuple. The Boolean indicates
# whether variable is shp_i (if True), or an valid
# whether variable is shp_i (if True), or an valid
...
@@ -976,34 +978,32 @@ class Scan(PureOp):
...
@@ -976,34 +978,32 @@ class Scan(PureOp):
shp
.
append
(
v_shp_i
[
0
])
shp
.
append
(
v_shp_i
[
0
])
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
+=
[
x
for
x
in
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
self
.
n_shared_outs
]
]
input_shapes
[
offset
:
offset
+
self
.
n_shared_outs
]
]
return
scan_outs
return
scan_outs
### GRAD FUNCTION
### GRAD FUNCTION
def
grad
(
self
,
args
,
g_outs
):
def
grad
(
self
,
args
,
g_outs
):
# 1. forward pass - get the outputs after applying scan
# 1. forward pass - get the outputs after applying scan
scan_outputs
=
self
(
*
args
)
scan_outputs
=
self
(
*
args
)
# 2. make sure they are given as a list
# 2. make sure they are given as a list
if
not
(
type
(
scan_outputs
)
in
(
list
,
tuple
)):
if
not
(
type
(
scan_outputs
)
in
(
list
,
tuple
)):
scan_outputs
=
[
scan_outputs
]
scan_outputs
=
[
scan_outputs
]
# 3. un-group / unzip the inputs
# 3. un-group / unzip the inputs
# Note ! We don't want to use the actual same variable as the ones
# Note ! We don't want to use the actual same variable as the ones
# used by the original scan, rather create clones of them
# used by the original scan, rather create clones of them
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
,
'_grad'
)
self
.
outputs
,
'_grad'
)
self_inputs
=
rval
[
0
]
self_inputs
=
rval
[
0
]
self_outputs
=
rval
[
1
]
self_outputs
=
rval
[
1
]
seqs
=
self_inputs
[:
self
.
n_seqs
]
seqs
=
self_inputs
[:
self
.
n_seqs
]
offset
=
self
.
n_seqs
n_ins_mit_mot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
offset
=
self
.
n_seqs
in
xrange
(
self
.
n_mit_mot
)])
n_ins_mit_mot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
in
xrange
(
self
.
n_mit_mot
)
])
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
+=
n_ins_mit_mot
offset
+=
n_ins_mit_mot
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
...
@@ -1082,6 +1082,11 @@ class Scan(PureOp):
...
@@ -1082,6 +1082,11 @@ class Scan(PureOp):
# 7.3. compute gradients of the inputs given one output
# 7.3. compute gradients of the inputs given one output
for
dx
,
out
in
enumerate
(
clean_outputs
):
for
dx
,
out
in
enumerate
(
clean_outputs
):
inner_g_out
=
safe_new
(
out
)
inner_g_out
=
safe_new
(
out
)
###
#### I need to clip the gradient HERE !!
if
g_outs_no_shared
[
dx
]:
if
g_outs_no_shared
[
dx
]:
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
else
:
else
:
...
...
theano/scan_module/scan_opt.py
浏览文件 @
5a3a1d82
...
@@ -4,11 +4,11 @@ This module provides optimizations for scan
...
@@ -4,11 +4,11 @@ This module provides optimizations for scan
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"Frederic Bastien "
"James Bergstra "
"James Bergstra "
"Pascal Lamblin "
"Pascal Lamblin "
"Arnaud Bergeron "
)
"Arnaud Bergeron "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
@@ -32,16 +32,20 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
...
@@ -32,16 +32,20 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
# Logging function for sending warning or info
# Logging function for sending warning or info
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan_opt'
)
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan_opt'
)
list_opt_slice
=
[
tensor
.
opt
.
local_abs_merge
,
list_opt_slice
=
[
tensor
.
opt
.
local_abs_merge
,
tensor
.
opt
.
local_mul_switch_sink
,
tensor
.
opt
.
local_mul_switch_sink
,
tensor
.
opt
.
local_upcast_elemwise_constant_inputs
,
tensor
.
opt
.
local_upcast_elemwise_constant_inputs
,
tensor
.
opt
.
local_remove_switch_const_cond
,
tensor
.
opt
.
local_remove_switch_const_cond
,
tensor
.
opt
.
constant_folding
]
tensor
.
opt
.
constant_folding
]
def
warning
(
*
msg
):
def
warning
(
*
msg
):
_logger
.
warning
(
'WARNING theano.scan: '
+
' '
.
join
(
msg
))
_logger
.
warning
(
'WARNING theano.scan: '
+
' '
.
join
(
msg
))
def
info
(
*
msg
):
def
info
(
*
msg
):
_logger
.
info
(
'INFO theano.scan: '
+
' '
.
join
(
msg
))
_logger
.
info
(
'INFO theano.scan: '
+
' '
.
join
(
msg
))
@gof.local_optimizer
([
None
])
@gof.local_optimizer
([
None
])
def
remove_constants_and_unused_inputs_scan
(
node
):
def
remove_constants_and_unused_inputs_scan
(
node
):
...
@@ -58,9 +62,9 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -58,9 +62,9 @@ def remove_constants_and_unused_inputs_scan(node):
return
False
return
False
op
=
node
.
op
op
=
node
.
op
# We only need to take care of sequences and other arguments
# We only need to take care of sequences and other arguments
st
=
op
.
n_seqs
st
=
op
.
n_seqs
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
st
+=
op
.
n_sit_sot
st
+=
op
.
n_sit_sot
st
+=
op
.
n_shared_outs
st
+=
op
.
n_shared_outs
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
op
.
inputs
,
op
.
outputs
)
op_ins
,
op_outs
=
scan_utils
.
reconstruct_graph
(
op
.
inputs
,
op
.
outputs
)
...
@@ -70,17 +74,17 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -70,17 +74,17 @@ def remove_constants_and_unused_inputs_scan(node):
out_stuff_inner
=
op_ins
[
op
.
n_seqs
:
st
]
out_stuff_inner
=
op_ins
[
op
.
n_seqs
:
st
]
non_seqs
=
op_ins
[
st
:]
non_seqs
=
op_ins
[
st
:]
st
=
(
op
.
n_seqs
+
st
=
(
op
.
n_seqs
+
op
.
n_mit_mot
+
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
+
op
.
n_nit_sot
+
op
.
n_shared_outs
+
1
)
op
.
n_shared_outs
+
1
)
outer_non_seqs
=
node
.
inputs
[
st
:]
outer_non_seqs
=
node
.
inputs
[
st
:]
out_stuff_outer
=
node
.
inputs
[
1
+
op
.
n_seqs
:
st
]
out_stuff_outer
=
node
.
inputs
[
1
+
op
.
n_seqs
:
st
]
# To replace constants in the outer graph by clones in the inner graph
# To replace constants in the outer graph by clones in the inner graph
givens
=
{}
givens
=
{}
# All the inputs of the inner graph of the new scan
# All the inputs of the inner graph of the new scan
nw_inner
=
[]
nw_inner
=
[]
# Same for the outer graph, initialized w/ number of steps
# Same for the outer graph, initialized w/ number of steps
...
@@ -88,18 +92,18 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -88,18 +92,18 @@ def remove_constants_and_unused_inputs_scan(node):
all_ins
=
gof
.
graph
.
inputs
(
op_outs
)
all_ins
=
gof
.
graph
.
inputs
(
op_outs
)
for
idx
in
xrange
(
op
.
n_seqs
):
for
idx
in
xrange
(
op
.
n_seqs
):
if
(
isinstance
(
node
.
inputs
[
idx
+
1
],
tensor
.
TensorConstant
)
and
if
(
isinstance
(
node
.
inputs
[
idx
+
1
],
tensor
.
TensorConstant
)
and
node
.
inputs
[
idx
+
1
]
.
tag
.
unique_value
is
not
None
):
node
.
inputs
[
idx
+
1
]
.
tag
.
unique_value
is
not
None
):
try
:
try
:
# This works if input is a constant that has all entries
# This works if input is a constant that has all entries
# equal
# equal
val
=
tensor
.
get_constant_value
(
node
.
inputs
[
idx
+
1
])
val
=
tensor
.
get_constant_value
(
node
.
inputs
[
idx
+
1
])
givens
[
op_ins
[
idx
]]
=
node
.
inputs
[
idx
+
1
]
.
clone
()[
0
]
givens
[
op_ins
[
idx
]]
=
node
.
inputs
[
idx
+
1
]
.
clone
()[
0
]
except
TypeError
:
except
TypeError
:
pass
pass
elif
op_ins
[
idx
]
in
all_ins
:
elif
op_ins
[
idx
]
in
all_ins
:
nw_inner
+=
[
op_ins
[
idx
]]
nw_inner
+=
[
op_ins
[
idx
]]
nw_outer
+=
[
node
.
inputs
[
idx
+
1
]]
nw_outer
+=
[
node
.
inputs
[
idx
+
1
]]
nw_n_seqs
=
len
(
nw_inner
)
nw_n_seqs
=
len
(
nw_inner
)
# Add outputs stuff
# Add outputs stuff
...
@@ -114,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -114,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_outer
+=
[
nw_out
]
nw_outer
+=
[
nw_out
]
if
len
(
nw_inner
)
!=
len
(
op_ins
):
if
len
(
nw_inner
)
!=
len
(
op_ins
):
op_outs
=
scan_utils
.
clone
(
op_outs
,
replace
=
givens
)
op_outs
=
scan_utils
.
clone
(
op_outs
,
replace
=
givens
)
nw_info
=
op
.
info
.
copy
()
nw_info
=
op
.
info
.
copy
()
nw_info
[
'n_seqs'
]
=
nw_n_seqs
nw_info
[
'n_seqs'
]
=
nw_n_seqs
# DEBUG CHECK
# DEBUG CHECK
...
@@ -128,11 +132,12 @@ scan_seqopt = theano.gof.SequenceDB()
...
@@ -128,11 +132,12 @@ scan_seqopt = theano.gof.SequenceDB()
optdb
.
register
(
'scan_seqopt'
,
scan_seqopt
,
1.9
,
'fast_run'
,
'scan'
)
optdb
.
register
(
'scan_seqopt'
,
scan_seqopt
,
1.9
,
'fast_run'
,
'scan'
)
scan_seqopt
.
register
(
'scanOp_remove_constants_and_unused_inputs'
,
scan_seqopt
.
register
(
'scanOp_remove_constants_and_unused_inputs'
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
opt
.
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
ignore_newtrees
=
True
),
5
,
5
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
# This is a global opt for historical reason
# This is a global opt for historical reason
# It should be possible to change it to a local opt.
# It should be possible to change it to a local opt.
class
PushOutNonSeqScan
(
gof
.
Optimizer
):
class
PushOutNonSeqScan
(
gof
.
Optimizer
):
...
@@ -140,10 +145,9 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -140,10 +145,9 @@ class PushOutNonSeqScan(gof.Optimizer):
def
__init__
(
self
):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
gof
.
Optimizer
.
__init__
(
self
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
nodelist
=
[
x
for
x
in
env
.
toposort
()
if
isinstance
(
x
.
op
,
nodelist
=
[
x
for
x
in
env
.
toposort
()
if
isinstance
(
x
.
op
,
scan_op
.
Scan
)]
scan_op
.
Scan
)]
...
@@ -152,34 +156,31 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -152,34 +156,31 @@ class PushOutNonSeqScan(gof.Optimizer):
def
process_node
(
self
,
env
,
node
):
def
process_node
(
self
,
env
,
node
):
# this flag tells if there was any change during the last iterations
# this flag tells if there was any change during the last iterations
changed
=
True
changed
=
True
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
clean_inputs
,
clean_outputs
=
scan_utils
.
reconstruct_graph
(
node
.
op
.
inputs
,
node
.
op
.
outputs
)
node
.
op
.
inputs
,
node
.
op
.
outputs
)
local_env
=
gof
.
Env
(
clean_inputs
,
clean_outputs
)
local_env
=
gof
.
Env
(
clean_inputs
,
clean_outputs
)
max_iterations
=
2
*
len
(
local_env
.
toposort
())
+
3
max_iterations
=
2
*
len
(
local_env
.
toposort
())
+
3
counts
=
0
counts
=
0
to_remove
=
[]
to_remove
=
[]
to_replace
=
[]
to_replace
=
[]
replace_with_in
=
[]
replace_with_in
=
[]
replace_with_out
=
[]
replace_with_out
=
[]
op
=
node
.
op
op
=
node
.
op
# Construct the list of non_sequences to simplify a few things
# Construct the list of non_sequences to simplify a few things
st
=
op
.
n_seqs
st
=
op
.
n_seqs
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
st
+=
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
op
.
tap_array
[:(
op
.
n_mit_mot
+
op
.
n_mit_sot
)]
]))
st
+=
op
.
n_sit_sot
st
+=
op
.
n_sit_sot
st
+=
op
.
n_shared_outs
st
+=
op
.
n_shared_outs
non_seqs
=
clean_inputs
[
st
:]
non_seqs
=
clean_inputs
[
st
:]
st
=
(
op
.
n_seqs
+
st
=
(
op
.
n_seqs
+
op
.
n_mit_mot
+
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
+
op
.
n_nit_sot
+
op
.
n_shared_outs
+
1
)
op
.
n_shared_outs
+
1
)
outer_non_seqs
=
node
.
inputs
[
st
:]
outer_non_seqs
=
node
.
inputs
[
st
:]
assert
len
(
non_seqs
)
==
len
(
outer_non_seqs
)
assert
len
(
non_seqs
)
==
len
(
outer_non_seqs
)
while
changed
and
counts
<
max_iterations
:
while
changed
and
counts
<
max_iterations
:
...
@@ -187,15 +188,15 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -187,15 +188,15 @@ class PushOutNonSeqScan(gof.Optimizer):
changed
=
False
changed
=
False
for
nd
in
local_env
.
toposort
():
for
nd
in
local_env
.
toposort
():
if
(
numpy
.
all
([
(
x
in
non_seqs
)
or
if
(
numpy
.
all
([
(
x
in
non_seqs
)
or
(
x
.
owner
in
to_remove
)
or
(
x
.
owner
in
to_remove
)
or
isinstance
(
x
,
tensor
.
Constant
)
isinstance
(
x
,
tensor
.
Constant
)
for
x
in
nd
.
inputs
])
and
for
x
in
nd
.
inputs
])
and
# we can do this because the assumption is that a
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
# function and not somewhere in the middle ..
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)
and
# and we didn't already looked at this node
# and we didn't already looked at this node
not
nd
in
to_remove
not
nd
in
to_remove
):
):
...
@@ -206,49 +207,50 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -206,49 +207,50 @@ class PushOutNonSeqScan(gof.Optimizer):
outside_ins
=
[]
outside_ins
=
[]
for
x
in
nd
.
inputs
:
for
x
in
nd
.
inputs
:
if
x
in
non_seqs
:
if
x
in
non_seqs
:
outside_ins
+=
[
outer_non_seqs
[
non_seqs
.
index
(
x
)]]
outside_ins
+=
[
outer_non_seqs
[
non_seqs
.
index
(
x
)]]
elif
x
in
to_replace
:
elif
x
in
to_replace
:
outside_ins
+=
[
replace_with_out
[
to_replace
.
index
(
x
)]]
outside_ins
+=
[
replace_with_out
[
to_replace
.
index
(
x
)]]
elif
isinstance
(
x
,
theano
.
Constant
):
elif
isinstance
(
x
,
theano
.
Constant
):
outside_ins
+=
[
x
.
clone
()]
outside_ins
+=
[
x
.
clone
()]
else
:
else
:
raise
Exception
(
raise
Exception
(
(
'Error in the `scan_pushout_non_seq_operations`'
(
'Error in the `scan_pushout_non_seq_'
'. The optimization tries to move some '
'operations`. The optimization tries '
'computation fron scan which is not allowed '
'to move some computation fron scan '
'to move. Report this on theano-users list'
),
x
)
'which is not allowed to move. Report '
'this on theano-users list'
),
x
)
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
# Step 2. Create variables for replacements
# Step 2. Create variables for replacements
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
for
idx
,
y
in
enumerate
(
nd
.
outputs
):
y_place_holder
=
scan_utils
.
safe_new
(
y
,
'_replace'
)
y_place_holder
=
scan_utils
.
safe_new
(
y
,
'_replace'
)
to_replace
+=
[
y
]
to_replace
+=
[
y
]
replace_with_in
+=
[
y_place_holder
]
replace_with_in
+=
[
y_place_holder
]
assert
type
(
y
)
==
type
(
nw_outer_node
.
outputs
[
idx
])
assert
type
(
y
)
==
type
(
nw_outer_node
.
outputs
[
idx
])
replace_with_out
+=
[
nw_outer_node
.
outputs
[
idx
]]
replace_with_out
+=
[
nw_outer_node
.
outputs
[
idx
]]
changed
=
True
changed
=
True
if
counts
>=
max_iterations
:
if
counts
>=
max_iterations
:
raise
Exception
(
(
'Error in the `scan_pushout_non_seq_operations`.'
raise
Exception
(
'Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
' The optimization exhausted the maximal number '
'of iterations allowed!'
)
)
'of iterations allowed!'
)
# We need to check all candidate replacements and choose those that
# We need to check all candidate replacements and choose those that
# make sense for us
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
# components of the inner function
clean_to_replace
=
[]
clean_to_replace
=
[]
clean_replace_with_in
=
[]
clean_replace_with_in
=
[]
clean_replace_with_out
=
[]
clean_replace_with_out
=
[]
existent_nodes
=
[
nd
for
nd
in
local_env
.
toposort
()
existent_nodes
=
[
nd
for
nd
in
local_env
.
toposort
()
if
nd
not
in
to_remove
]
if
nd
not
in
to_remove
]
to_keep
=
[]
to_keep
=
[]
for
nd
in
existent_nodes
:
for
nd
in
existent_nodes
:
to_keep
+=
nd
.
inputs
to_keep
+=
nd
.
inputs
for
idx
,
out
in
enumerate
(
to_replace
):
for
idx
,
out
in
enumerate
(
to_replace
):
if
out
in
to_keep
and
out
.
owner
not
in
existent_nodes
:
if
out
in
to_keep
and
out
.
owner
not
in
existent_nodes
:
clean_to_replace
+=
[
out
]
clean_to_replace
+=
[
out
]
clean_replace_with_in
+=
[
replace_with_in
[
idx
]]
clean_replace_with_in
+=
[
replace_with_in
[
idx
]]
clean_replace_with_out
+=
[
replace_with_out
[
idx
]]
clean_replace_with_out
+=
[
replace_with_out
[
idx
]]
if
len
(
clean_to_replace
)
>
0
:
if
len
(
clean_to_replace
)
>
0
:
...
@@ -256,7 +258,7 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -256,7 +258,7 @@ class PushOutNonSeqScan(gof.Optimizer):
givens
=
{}
givens
=
{}
nw_outer
=
[]
nw_outer
=
[]
nw_inner
=
[]
nw_inner
=
[]
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
for
to_repl
,
repl_in
,
repl_out
in
zip
(
clean_to_replace
,
clean_replace_with_in
,
clean_replace_with_in
,
clean_replace_with_out
):
clean_replace_with_out
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
if
isinstance
(
repl_out
,
theano
.
Constant
):
...
@@ -274,8 +276,24 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -274,8 +276,24 @@ class PushOutNonSeqScan(gof.Optimizer):
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nwScan
=
scan_op
.
Scan
(
op_ins
,
op_outs
,
op
.
info
)
nw_node
=
nwScan
.
make_node
(
*
(
node
.
inputs
+
nw_outer
))
nw_node
=
nwScan
.
make_node
(
*
(
node
.
inputs
+
nw_outer
))
env
.
replace_all_validate
(
zip
(
node
.
outputs
,
nw_node
.
outputs
),
env
.
replace_all_validate
(
zip
(
node
.
outputs
,
nw_node
.
outputs
),
reason
=
'scan_push_computation_out'
)
reason
=
'scan_push_computation_out'
)
return
True
return
True
elif
to_keep
==
[]:
# Nothing in the inner graph should be kept
replace_with
=
{}
for
idx
,
out
in
enumerate
(
to_replace
):
if
out
in
local_env
.
outputs
:
x
=
node
.
outputs
[
local_env
.
outputs
.
index
(
out
)]
y
=
replace_with_out
[
idx
]
shape
=
[
y
.
shape
[
idx
]
for
idx
in
xrange
(
y
.
ndim
)]
replace_with
[
x
]
=
tensor
.
alloc
(
y
,
node
.
inputs
[
0
],
*
shape
)
# We need to add one extra dimension to the outputs
env
.
replace_all_validate
(
replace_with
.
items
(),
reason
=
'scan_push_computation_out'
)
else
:
else
:
return
False
return
False
...
@@ -290,17 +308,17 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops',
...
@@ -290,17 +308,17 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops',
@gof.local_optimizer
([
None
])
@gof.local_optimizer
([
None
])
def
scan_make_inplace
(
node
):
def
scan_make_inplace
(
node
):
op
=
node
.
op
op
=
node
.
op
if
(
isinstance
(
op
,
scan_op
.
Scan
)
and
if
(
isinstance
(
op
,
scan_op
.
Scan
)
and
(
not
op
.
info
[
'inplace'
])
and
(
not
op
.
info
[
'inplace'
])
and
(
not
op
.
info
[
'gpu'
])):
(
not
op
.
info
[
'gpu'
])):
info
=
op
.
info
.
copy
()
info
=
op
.
info
.
copy
()
info
[
'inplace'
]
=
True
info
[
'inplace'
]
=
True
# inputs corresponding to sequences and n_steps
# inputs corresponding to sequences and n_steps
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls
=
op
.
outer_mitmot
(
node
)
ls
=
op
.
outer_mitmot
(
node
)
ls
+=
op
.
outer_mitsot
(
node
)
ls
+=
op
.
outer_mitsot
(
node
)
ls
+=
op
.
outer_sitsot
(
node
)
ls
+=
op
.
outer_sitsot
(
node
)
ls_end
=
op
.
outer_shared
(
node
)
ls_end
=
op
.
outer_shared
(
node
)
ls_end
+=
op
.
outer_nitsot
(
node
)
ls_end
+=
op
.
outer_nitsot
(
node
)
ls_end
+=
op
.
outer_non_seqs
(
node
)
ls_end
+=
op
.
outer_non_seqs
(
node
)
n_outs
=
len
(
ls
)
n_outs
=
len
(
ls
)
...
@@ -309,19 +327,18 @@ def scan_make_inplace(node):
...
@@ -309,19 +327,18 @@ def scan_make_inplace(node):
ls
[
idx
]
=
deep_copy_op
(
ls
[
idx
])
ls
[
idx
]
=
deep_copy_op
(
ls
[
idx
])
inputs
=
ls_begin
+
ls
+
ls_end
inputs
=
ls_begin
+
ls
+
ls_end
new_op
=
scan_op
.
Scan
(
op
.
inputs
new_op
=
scan_op
.
Scan
(
op
.
inputs
,
,
op
.
outputs
op
.
outputs
,
,
info
)
info
)
return
new_op
.
make_node
(
*
inputs
)
.
outputs
return
new_op
.
make_node
(
*
inputs
)
.
outputs
return
False
return
False
optdb
.
register
(
'scanOp_make_inplace'
optdb
.
register
(
'scanOp_make_inplace'
,
,
opt
.
in2out
(
scan_make_inplace
,
ignore_newtrees
=
True
)
opt
.
in2out
(
scan_make_inplace
,
ignore_newtrees
=
True
),
,
75
75
,
,
'fast_run'
'fast_run'
,
,
'inplace'
'inplace'
,
,
'scan'
)
'scan'
)
class
ScanSaveMem
(
gof
.
Optimizer
):
class
ScanSaveMem
(
gof
.
Optimizer
):
...
@@ -329,24 +346,25 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -329,24 +346,25 @@ class ScanSaveMem(gof.Optimizer):
def
__init__
(
self
):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
gof
.
Optimizer
.
__init__
(
self
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
def
process_node
(
self
,
env
,
node
):
def
process_node
(
self
,
env
,
node
):
# helpful functions
# helpful functions
def
select_min
(
x
,
y
):
def
select_min
(
x
,
y
):
if
x
is
None
:
if
x
is
None
:
return
y
return
y
if
y
is
None
:
if
y
is
None
:
return
x
return
x
return
tensor
.
minimum
(
x
,
y
)
return
tensor
.
minimum
(
x
,
y
)
def
select_max
(
x
,
y
):
def
select_max
(
x
,
y
):
if
x
is
None
:
if
x
is
None
:
return
y
return
y
if
y
is
None
:
if
y
is
None
:
return
x
return
x
return
tensor
.
maximum
(
x
,
y
)
return
tensor
.
maximum
(
x
,
y
)
def
sanitize
(
x
):
def
sanitize
(
x
):
if
x
is
None
:
if
x
is
None
:
...
@@ -367,9 +385,9 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -367,9 +385,9 @@ class ScanSaveMem(gof.Optimizer):
op
=
node
.
op
op
=
node
.
op
c_outs
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
c_outs
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
init_l
=
[
0
for
x
in
xrange
(
op
.
n_mit_mot
)]
init_l
=
[
0
for
x
in
xrange
(
op
.
n_mit_mot
)]
init_l
+=
[
abs
(
numpy
.
min
(
v
))
for
v
in
op
.
tap_array
[
op
.
n_mit_mot
:]
]
init_l
+=
[
abs
(
numpy
.
min
(
v
))
for
v
in
op
.
tap_array
[
op
.
n_mit_mot
:]
]
init_l
+=
[
0
for
x
in
xrange
(
op
.
n_nit_sot
)]
init_l
+=
[
0
for
x
in
xrange
(
op
.
n_nit_sot
)]
# 2. Check the clients of each output and see for how many steps
# 2. Check the clients of each output and see for how many steps
# does scan need to run
# does scan need to run
...
@@ -392,13 +410,13 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -392,13 +410,13 @@ class ScanSaveMem(gof.Optimizer):
# change the number of steps in that case. To do this we set
# change the number of steps in that case. To do this we set
# global_nsteps to None which is seen as a flag that nothing needs
# global_nsteps to None which is seen as a flag that nothing needs
# to be done
# to be done
if
len
(
node
.
outputs
)
<=
c_outs
:
if
len
(
node
.
outputs
)
<=
c_outs
:
global_nsteps
=
{
'real'
:
-
1
,
'sym'
:
[]}
global_nsteps
=
{
'real'
:
-
1
,
'sym'
:
[]}
else
:
else
:
global_nsteps
=
None
global_nsteps
=
None
# Keeps track of the original slices that each client represent
# Keeps track of the original slices that each client represent
slices
=
[
None
for
o
in
node
.
outputs
]
slices
=
[
None
for
o
in
node
.
outputs
]
# A list for each output indicating how many intermediate values
# A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate
# should be stored. If negative it means none of the intermediate
...
@@ -409,31 +427,31 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -409,31 +427,31 @@ class ScanSaveMem(gof.Optimizer):
# Note that for mit_mot outputs and shared outputs we can not change
# Note that for mit_mot outputs and shared outputs we can not change
# the number of intermediate steps stored without affecting the
# the number of intermediate steps stored without affecting the
# result of the op
# result of the op
store_steps
=
[
0
for
o
in
xrange
(
op
.
n_mit_mot
)]
store_steps
=
[
0
for
o
in
xrange
(
op
.
n_mit_mot
)]
store_steps
+=
[
-
1
for
o
in
node
.
outputs
[
op
.
n_mit_mot
:
c_outs
]]
store_steps
+=
[
-
1
for
o
in
node
.
outputs
[
op
.
n_mit_mot
:
c_outs
]]
# Flag that says if an input has changed and we need to do something
# Flag that says if an input has changed and we need to do something
# or not
# or not
flag_store
=
False
flag_store
=
False
# 2.2 Loop over the clients
# 2.2 Loop over the clients
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
# look at all its clients
slices
[
i
]
=
[]
slices
[
i
]
=
[]
for
cl
,
_
in
out
.
clients
:
for
cl
,
_
in
out
.
clients
:
# 2.1 outputs of the function
# 2.1 outputs of the function
#=> output needs all its intermediate values
#=> output needs all its intermediate values
if
type
(
cl
)
==
str
:
if
type
(
cl
)
==
str
:
# if the node is actually an output, then
# if the node is actually an output, then
# we need to store the entire thing
# we need to store the entire thing
global_nsteps
=
None
global_nsteps
=
None
slices
[
i
]
=
None
slices
[
i
]
=
None
break
break
# 2.2 non-subtensor nodes
# 2.2 non-subtensor nodes
#=> output needs all its intermediate values
#=> output needs all its intermediate values
elif
not
isinstance
(
cl
.
op
,
tensor
.
basic
.
Subtensor
):
elif
not
isinstance
(
cl
.
op
,
tensor
.
basic
.
Subtensor
):
global_nsteps
=
None
global_nsteps
=
None
slices
[
i
]
=
None
slices
[
i
]
=
None
break
break
# 2.3 subtensor nodes
# 2.3 subtensor nodes
#=> output might need to store just a subset of its values
#=> output might need to store just a subset of its values
...
@@ -444,13 +462,11 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -444,13 +462,11 @@ class ScanSaveMem(gof.Optimizer):
if
this_slice
==
None
:
if
this_slice
==
None
:
# if unable to extract idx_list
# if unable to extract idx_list
#=> outputs needs all its intermediate values
#=> outputs needs all its intermediate values
global_nsteps
=
None
global_nsteps
=
None
slices
[
i
]
=
None
slices
[
i
]
=
None
break
break
# 2.3.2 extract the begin/end of the first dimension
# 2.3.2 extract the begin/end of the first dimension
if
i
>
op
.
n_mit_mot
:
if
i
>
op
.
n_mit_mot
:
try
:
try
:
length
=
shape_of
[
out
][
0
]
length
=
shape_of
[
out
][
0
]
...
@@ -463,26 +479,27 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -463,26 +479,27 @@ class ScanSaveMem(gof.Optimizer):
length
=
out
.
shape
[
0
]
length
=
out
.
shape
[
0
]
cf_slice
=
tensor
.
basic
.
get_canonical_form_slice
(
cf_slice
=
tensor
.
basic
.
get_canonical_form_slice
(
this_slice
[
0
],
length
)
this_slice
[
0
],
length
)
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
stop
is
None
):
this_slice
[
0
]
.
stop
is
None
):
global_nsteps
=
None
global_nsteps
=
None
break
break
if
isinstance
(
cf_slice
[
0
],
slice
):
if
isinstance
(
cf_slice
[
0
],
slice
):
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
stop
)
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
stop
)
else
:
else
:
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
])
+
1
stop
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
])
+
1
if
stop
==
sys
.
maxint
or
stop
==
length
:
if
stop
==
sys
.
maxint
or
stop
==
length
:
stop
=
None
stop
=
None
else
:
else
:
# there is a **gotcha** here ! Namely, scan returns an
# there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output as
# array that contains the initial state of the output
# well. Which means that if have a initial state of
# as well. Which means that if have a initial state of
# length 3, and you look for 5 steps you get an output y
# length 3, and you look for 5 steps you get an output
# of length 8. If you only use y[:5], this does not mean
# y of length 8. If you only use y[:5], this does not
# that you only need to loop for 5 steps but actually
# mean that you only need to loop for 5 steps but
# only for 2 steps ( the first 3 are the initial state)
# actually only for 2 steps ( the first 3 are the
# initial state)
stop
=
stop
-
init_l
[
i
]
stop
=
stop
-
init_l
[
i
]
# 2.3.3 we might get away with less number of steps
# 2.3.3 we might get away with less number of steps
...
@@ -494,10 +511,11 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -494,10 +511,11 @@ class ScanSaveMem(gof.Optimizer):
elif
(
type
(
stop
)
is
int
and
stop
==
sys
.
maxint
):
elif
(
type
(
stop
)
is
int
and
stop
==
sys
.
maxint
):
global_nsteps
=
None
global_nsteps
=
None
# yes if it is a int k, 0 < k < maxint
# yes if it is a int k, 0 < k < maxint
elif
(
type
(
stop
)
is
int
and
global_nsteps
[
'real'
]
<
stop
):
elif
(
type
(
stop
)
is
int
and
global_nsteps
[
'real'
]
<
stop
):
global_nsteps
[
'real'
]
=
stop
global_nsteps
[
'real'
]
=
stop
# yes if it is a int k, 0 < k < maxint
# yes if it is a int k, 0 < k < maxint
elif
(
type
(
stop
)
is
int
and
stop
>
0
):
elif
(
type
(
stop
)
is
int
and
stop
>
0
):
pass
pass
# not otherwise
# not otherwise
else
:
else
:
...
@@ -510,10 +528,10 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -510,10 +528,10 @@ class ScanSaveMem(gof.Optimizer):
# there are some symbolic tensors that limit the number of
# there are some symbolic tensors that limit the number of
# steps
# steps
if
len
(
global_nsteps
[
'sym'
])
==
0
:
if
len
(
global_nsteps
[
'sym'
])
==
0
:
sym_steps
=
None
sym_steps
=
None
else
:
else
:
sym_steps
=
global_nsteps
[
'sym'
][
0
]
sym_steps
=
global_nsteps
[
'sym'
][
0
]
for
c
in
global_nsteps
[
'sym'
][
1
:]:
for
c
in
global_nsteps
[
'sym'
][
1
:]:
sym_steps
=
tensor
.
maximum
(
sym_steps
,
c
)
sym_steps
=
tensor
.
maximum
(
sym_steps
,
c
)
...
@@ -527,12 +545,11 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -527,12 +545,11 @@ class ScanSaveMem(gof.Optimizer):
nw_steps
=
node
.
inputs
[
0
]
nw_steps
=
node
.
inputs
[
0
]
global_nsteps
=
None
global_nsteps
=
None
# 2.4 Loop over the clients again now looking just to see how many
# 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store
# intermediate steps to store
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
# look at all its clients
for
cl
,
_
in
out
.
clients
:
for
cl
,
_
in
out
.
clients
:
if
type
(
cl
)
==
str
:
if
type
(
cl
)
==
str
:
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
break
break
...
@@ -546,7 +563,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -546,7 +563,7 @@ class ScanSaveMem(gof.Optimizer):
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
break
break
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
if
(
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
start
is
None
):
this_slice
[
0
]
.
start
is
None
):
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
break
break
...
@@ -559,46 +576,48 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -559,46 +576,48 @@ class ScanSaveMem(gof.Optimizer):
except
Exception
:
except
Exception
:
length
=
out
.
shape
[
0
]
length
=
out
.
shape
[
0
]
cf_slice
=
tensor
.
basic
.
get_canonical_form_slice
(
cf_slice
=
tensor
.
basic
.
get_canonical_form_slice
(
this_slice
[
0
],
length
)
this_slice
[
0
],
length
)
if
isinstance
(
cf_slice
[
0
],
slice
):
if
isinstance
(
cf_slice
[
0
],
slice
):
start
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
start
)
start
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
]
.
start
)
else
:
else
:
start
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
])
start
=
tensor
.
basic
.
extract_constant
(
cf_slice
[
0
])
if
start
==
0
or
store_steps
[
i
]
==
0
:
if
start
==
0
or
store_steps
[
i
]
==
0
:
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
else
:
else
:
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
if
store_steps
[
i
]
!=
-
1
:
if
store_steps
[
i
]
!=
-
1
:
pval
=
select_max
(
pval
,
store_steps
[
i
])
pval
=
select_max
(
pval
,
store_steps
[
i
])
store_steps
[
i
]
=
pval
store_steps
[
i
]
=
pval
flag_store
=
True
flag_store
=
True
orphane_outs
=
[
i
for
i
,
x
in
enumerate
(
store_steps
)
orphane_outs
=
[
i
for
i
,
x
in
enumerate
(
store_steps
)
if
(
type
(
x
)
is
int
)
and
(
x
<
0
)
]
if
(
type
(
x
)
is
int
)
and
(
x
<
0
)
]
flag_store
=
flag_store
or
(
len
(
orphane_outs
)
>
0
)
flag_store
=
flag_store
or
(
len
(
orphane_outs
)
>
0
)
# 3. is there anything to change ?
# 3. is there anything to change ?
if
(
flag_store
or
global_nsteps
is
not
None
):
if
(
flag_store
or
global_nsteps
is
not
None
):
# 3.1 initialize inputs for the new scan
# 3.1 initialize inputs for the new scan
old_outputs
=
[]
old_outputs
=
[]
nw_inputs
=
list
(
node
.
inputs
)
nw_inputs
=
list
(
node
.
inputs
)
nw_inputs
[
0
]
=
nw_steps
nw_inputs
[
0
]
=
nw_steps
# 3.2 check orphane outputs to see if we can eliminate any
# 3.2 check orphane outputs to see if we can eliminate any
required
,
not_required
=
\
required
,
not_required
=
\
scan_utils
.
scan_can_remove_outs
(
node
.
op
scan_utils
.
scan_can_remove_outs
(
node
.
op
,
,
orphane_outs
)
orphane_outs
)
# 3.3. compose replace pairs for those nodes that need not
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required
# to store everything in memory ( or ar orphane and required
# by the inner function .. )
# by the inner function .. )
replaced_outs
=
[]
replaced_outs
=
[]
offset
=
1
+
op
.
n_seqs
+
op
.
n_mit_mot
offset
=
1
+
op
.
n_seqs
+
op
.
n_mit_mot
for
idx
,
_val
in
enumerate
(
store_steps
[
op
.
n_mit_mot
:]):
for
idx
,
_val
in
enumerate
(
store_steps
[
op
.
n_mit_mot
:]):
i
=
idx
+
op
.
n_mit_mot
i
=
idx
+
op
.
n_mit_mot
if
not
(
type
(
_val
)
is
int
and
_val
<=
0
and
i
not
in
required
):
if
not
(
type
(
_val
)
is
int
and
_val
<=
0
and
i
not
in
required
):
if
idx
+
op
.
n_mit_mot
in
required
:
if
idx
+
op
.
n_mit_mot
in
required
:
val
=
1
val
=
1
else
:
else
:
val
=
_val
val
=
_val
...
@@ -610,21 +629,21 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -610,21 +629,21 @@ class ScanSaveMem(gof.Optimizer):
# a) the input is a set_subtensor, in that case we
# a) the input is a set_subtensor, in that case we
# can replace the initial tensor by a slice,
# can replace the initial tensor by a slice,
# b) it is not, and we simply take a slice of it.
# b) it is not, and we simply take a slice of it.
if
(
nw_inputs
[
offset
+
idx
]
.
owner
and
if
(
nw_inputs
[
offset
+
idx
]
.
owner
and
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
,
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
,
tensor
.
IncSubtensor
)):
tensor
.
IncSubtensor
)):
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tensor
.
as_tensor_variable
(
val
-
init_l
[
i
]))
tensor
.
as_tensor_variable
(
val
-
init_l
[
i
]))
tmp
=
pre_constant_merge
([
tmp
])[
0
]
tmp
=
pre_constant_merge
([
tmp
])[
0
]
nw_input
=
scan_utils
.
expand
(
_nw_input
,
tmp
)
nw_input
=
scan_utils
.
expand
(
_nw_input
,
tmp
)
else
:
else
:
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tensor
.
as_tensor_variable
(
val
))
tensor
.
as_tensor_variable
(
val
))
tmp
=
pre_constant_merge
([
tmp
])[
0
]
tmp
=
pre_constant_merge
([
tmp
])[
0
]
nw_input
=
nw_inputs
[
offset
+
idx
][:
tmp
]
nw_input
=
nw_inputs
[
offset
+
idx
][:
tmp
]
nw_inputs
[
offset
+
idx
]
=
nw_input
nw_inputs
[
offset
+
idx
]
=
nw_input
replaced_outs
.
append
(
op
.
n_mit_mot
+
idx
)
replaced_outs
.
append
(
op
.
n_mit_mot
+
idx
)
odx
=
op
.
n_mit_mot
+
idx
odx
=
op
.
n_mit_mot
+
idx
old_outputs
+=
[(
odx
,
[
x
[
0
]
.
outputs
[
0
]
for
x
in
old_outputs
+=
[(
odx
,
[
x
[
0
]
.
outputs
[
0
]
for
x
in
...
@@ -632,8 +651,8 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -632,8 +651,8 @@ class ScanSaveMem(gof.Optimizer):
# If there is no memory pre-allocated for this output
# If there is no memory pre-allocated for this output
elif
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
:
elif
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
+
op
.
n_nit_sot
:
pos
=
(
op
.
n_mit_mot
+
idx
+
op
.
n_seqs
pos
=
(
op
.
n_mit_mot
+
idx
+
op
.
n_seqs
+
+
1
+
op
.
n_shared_outs
)
1
+
op
.
n_shared_outs
)
if
nw_inputs
[
pos
]
==
node
.
inputs
[
0
]:
if
nw_inputs
[
pos
]
==
node
.
inputs
[
0
]:
nw_inputs
[
pos
]
=
val
nw_inputs
[
pos
]
=
val
odx
=
op
.
n_mit_mot
+
idx
odx
=
op
.
n_mit_mot
+
idx
...
@@ -646,43 +665,41 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -646,43 +665,41 @@ class ScanSaveMem(gof.Optimizer):
for
idx
,
val
in
enumerate
(
store_steps
[
op
.
n_mit_mot
:]):
for
idx
,
val
in
enumerate
(
store_steps
[
op
.
n_mit_mot
:]):
if
val
==
0
:
if
val
==
0
:
if
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
:
if
idx
<
op
.
n_mit_sot
+
op
.
n_sit_sot
:
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
odx
=
op
.
n_mit_mot
+
idx
odx
=
op
.
n_mit_mot
+
idx
nw_input
=
scan_utils
.
expand
(
_nw_input
,
nw_steps
)
nw_input
=
scan_utils
.
expand
(
_nw_input
,
nw_steps
)
nw_inputs
[
offset
+
idx
]
=
nw_input
nw_inputs
[
offset
+
idx
]
=
nw_input
elif
idx
<
(
op
.
n_mit_sot
+
op
.
n_sit_sot
+
elif
idx
<
(
op
.
n_mit_sot
+
op
.
n_sit_sot
+
+
op
.
n_nit_sot
):
op
.
n_nit_sot
):
in_idx
=
offset
+
idx
+
op
.
n_shared_outs
in_idx
=
offset
+
idx
+
op
.
n_shared_outs
if
nw_inputs
[
in_idx
]
==
node
.
inputs
[
0
]:
if
nw_inputs
[
in_idx
]
==
node
.
inputs
[
0
]:
nw_inputs
[
in_idx
]
=
nw_steps
nw_inputs
[
in_idx
]
=
nw_steps
odx
=
op
.
n_mit_mot
+
idx
odx
=
op
.
n_mit_mot
+
idx
# 3.5 Remove unwanted orphane outputs
# 3.5 Remove unwanted orphane outputs
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
\
(
inps
,
outs
,
info
,
node_ins
,
compress_map
)
=
\
scan_utils
.
compress_outs
(
op
,
not_required
,
nw_inputs
)
scan_utils
.
compress_outs
(
op
,
not_required
,
nw_inputs
)
inv_compress_map
=
{}
inv_compress_map
=
{}
for
k
,
v
in
compress_map
.
items
():
for
k
,
v
in
compress_map
.
items
():
inv_compress_map
[
v
]
=
k
inv_compress_map
[
v
]
=
k
node_ins
=
[
pre_greedy_local_optimizer
(
list_opt_slice
,
x
)
for
x
in
node_ins
=
[
pre_greedy_local_optimizer
(
list_opt_slice
,
x
)
for
x
in
node_ins
]
node_ins
]
node_ins
=
pre_constant_merge
(
node_ins
)
node_ins
=
pre_constant_merge
(
node_ins
)
# 3.6 Compose the new scan
# 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
# twice since bad things usually happen if I do that
info
[
'_scan_merge_visited'
]
=
True
info
[
'_scan_merge_visited'
]
=
True
new_outs
=
scan_op
.
Scan
(
inps
new_outs
=
scan_op
.
Scan
(
inps
,
,
outs
outs
,
,
info
)
.
make_node
(
*
node_ins
)
.
outputs
info
)
.
make_node
(
*
node_ins
)
.
outputs
old_new
=
[]
old_new
=
[]
# 3.7 Get replace pairs for those outputs that do not change
# 3.7 Get replace pairs for those outputs that do not change
# the number of intermediate steps stored
# the number of intermediate steps stored
for
idx
,
sl
in
enumerate
(
slices
):
for
idx
,
sl
in
enumerate
(
slices
):
if
global_nsteps
and
sl
is
not
None
and
store_steps
[
idx
]
==
0
:
if
global_nsteps
and
sl
is
not
None
and
store_steps
[
idx
]
==
0
:
for
hdx
,
cl
in
enumerate
(
node
.
outputs
[
idx
]
.
clients
):
for
hdx
,
cl
in
enumerate
(
node
.
outputs
[
idx
]
.
clients
):
cnf_slice
,
old_slices
=
sl
[
hdx
]
cnf_slice
,
old_slices
=
sl
[
hdx
]
# Sanitize the nw_slice by converting ints back into
# Sanitize the nw_slice by converting ints back into
# constants :) I only need to do this for the first
# constants :) I only need to do this for the first
...
@@ -697,18 +714,16 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -697,18 +714,16 @@ class ScanSaveMem(gof.Optimizer):
else
:
else
:
fslice
=
sanitize
(
cnf_slice
[
0
])
fslice
=
sanitize
(
cnf_slice
[
0
])
nw_slice
=
(
fslice
,)
+
tuple
(
old_slices
[
1
:])
nw_slice
=
(
fslice
,)
+
tuple
(
old_slices
[
1
:])
nw_pos
=
inv_compress_map
[
idx
]
nw_pos
=
inv_compress_map
[
idx
]
nw_out
=
new_outs
[
nw_pos
]
nw_out
=
new_outs
[
nw_pos
]
subtens
=
tensor
.
basic
.
Subtensor
(
nw_slice
)
subtens
=
tensor
.
basic
.
Subtensor
(
nw_slice
)
# slice inputs
# slice inputs
sl_ins
=
tensor
.
basic
.
Subtensor
.
collapse
(
sl_ins
=
tensor
.
basic
.
Subtensor
.
collapse
(
nw_slice
nw_slice
,
,
lambda
entry
:
isinstance
(
entry
lambda
entry
:
isinstance
(
entry
,
,
tensor
.
Variable
))
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
*
sl_ins
)
.
outputs
[
0
]
*
sl_ins
)
.
outputs
[
0
]
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
...
@@ -721,34 +736,35 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -721,34 +736,35 @@ class ScanSaveMem(gof.Optimizer):
if
len
(
old_outs
)
>
0
:
if
len
(
old_outs
)
>
0
:
nw_pos
=
compress_map
[
pos
]
nw_pos
=
compress_map
[
pos
]
nw_out
=
new_outs
[
nw_pos
]
nw_out
=
new_outs
[
nw_pos
]
for
k
,
old
in
enumerate
(
old_outs
):
for
k
,
old
in
enumerate
(
old_outs
):
# Get the correct slice
# Get the correct slice
cnf_slice
,
old_slices
=
slices
[
pos
][
k
]
cnf_slice
,
old_slices
=
slices
[
pos
][
k
]
if
type
(
cnf_slice
[
0
])
is
slice
:
if
type
(
cnf_slice
[
0
])
is
slice
:
start
=
(
cnf_slice
[
0
]
.
start
-
nw_steps
-
start
=
(
cnf_slice
[
0
]
.
start
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
]
)
init_l
[
pos
]
+
store_steps
[
pos
])
if
(
cnf_slice
[
0
]
.
stop
is
not
None
and
if
(
cnf_slice
[
0
]
.
stop
is
not
None
and
cnf_slice
[
0
]
.
stop
!=
sys
.
maxint
):
cnf_slice
[
0
]
.
stop
!=
sys
.
maxint
):
stop
=
(
cnf_slice
[
0
]
.
stop
-
nw_steps
-
stop
=
(
cnf_slice
[
0
]
.
stop
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
])
init_l
[
pos
]
+
store_steps
[
pos
])
else
:
else
:
stop
=
None
stop
=
None
nw_slice
=
(
(
slice
(
sanitize
(
start
),
nw_slice
=
((
slice
(
sanitize
(
start
),
sanitize
(
stop
),
sanitize
(
stop
),
sanitize
(
cnf_slice
[
0
]
.
step
)),)
+
sanitize
(
cnf_slice
[
0
]
.
step
)),)
tuple
(
old_slices
[
1
:])
)
+
tuple
(
old_slices
[
1
:])
)
else
:
else
:
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
]
)
init_l
[
pos
]
+
store_steps
[
pos
]
)
nw_slice
=
(
sanitize
(
position
),)
+
tuple
(
old_slices
[
1
:])
nw_slice
=
(
sanitize
(
position
),)
+
\
tuple
(
old_slices
[
1
:])
subtens
=
tensor
.
basic
.
Subtensor
(
nw_slice
)
subtens
=
tensor
.
basic
.
Subtensor
(
nw_slice
)
sl_ins
=
tensor
.
basic
.
Subtensor
.
collapse
(
sl_ins
=
tensor
.
basic
.
Subtensor
.
collapse
(
nw_slice
nw_slice
,
,
lambda
entry
:
isinstance
(
entry
lambda
entry
:
isinstance
(
entry
,
,
tensor
.
Variable
))
tensor
.
Variable
))
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
new_o
=
subtens
.
make_node
(
new_outs
[
nw_pos
],
*
sl_ins
)
.
outputs
[
0
]
*
sl_ins
)
.
outputs
[
0
]
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
...
@@ -757,13 +773,12 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -757,13 +773,12 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes
# 3.9. Get replace pairs for all other nodes
if
flag_store
or
global_nsteps
is
not
None
:
if
flag_store
or
global_nsteps
is
not
None
:
for
idx
,
o
in
enumerate
(
node
.
outputs
):
for
idx
,
o
in
enumerate
(
node
.
outputs
):
if
not
(
idx
in
replaced_outs
)
and
not
idx
in
not_required
:
if
not
(
idx
in
replaced_outs
)
and
not
idx
in
not_required
:
nw_pos
=
compress_map
[
idx
]
nw_pos
=
compress_map
[
idx
]
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
old_new
+=
[(
o
,
new_outs
[
nw_pos
])]
env
.
replace_all_validate
(
old_new
,
reason
=
'scan_save_mem'
)
env
.
replace_all_validate
(
old_new
,
reason
=
'scan_save_mem'
)
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
...
@@ -776,16 +791,16 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -776,16 +791,16 @@ class ScanSaveMem(gof.Optimizer):
# Just before specialize to have the other optimization
# Just before specialize to have the other optimization
# like constant folding being applied
# like constant folding being applied
# This don't introduce inplace.
# This don't introduce inplace.
scan_seqopt
.
register
(
'scanOp_save_mem'
,
scan_seqopt
.
register
(
'scanOp_save_mem'
,
ScanSaveMem
(),
ScanSaveMem
(),
4
,
4
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
class
ScanMerge
(
gof
.
Optimizer
):
class
ScanMerge
(
gof
.
Optimizer
):
""" Graph Optimizer that merges different scan ops """
""" Graph Optimizer that merges different scan ops """
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
env
.
extend
(
gof
.
toolbox
.
ReplaceValidate
())
def
merge
(
self
,
nodes
):
def
merge
(
self
,
nodes
):
...
@@ -796,29 +811,26 @@ class ScanMerge(gof.Optimizer):
...
@@ -796,29 +811,26 @@ class ScanMerge(gof.Optimizer):
else
:
else
:
as_while
=
False
as_while
=
False
info
=
{}
info
=
{}
info
[
'tap_array'
]
=
[]
info
[
'n_seqs'
]
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
])
info
[
'tap_array'
]
=
[]
info
[
'n_mit_mot'
]
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
])
info
[
'n_seqs'
]
=
sum
([
nd
.
op
.
n_seqs
for
nd
in
nodes
])
info
[
'n_mit_mot_outs'
]
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
])
info
[
'n_mit_mot'
]
=
sum
([
nd
.
op
.
n_mit_mot
for
nd
in
nodes
])
info
[
'n_mit_mot_outs'
]
=
sum
([
nd
.
op
.
n_mit_mot_outs
for
nd
in
nodes
])
info
[
'mit_mot_out_slices'
]
=
[]
info
[
'mit_mot_out_slices'
]
=
[]
info
[
'n_mit_sot'
]
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
])
info
[
'n_mit_sot'
]
=
sum
([
nd
.
op
.
n_mit_sot
for
nd
in
nodes
])
info
[
'n_sit_sot'
]
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
])
info
[
'n_sit_sot'
]
=
sum
([
nd
.
op
.
n_sit_sot
for
nd
in
nodes
])
info
[
'n_shared_outs'
]
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
])
info
[
'n_shared_outs'
]
=
sum
([
nd
.
op
.
n_shared_outs
for
nd
in
nodes
])
info
[
'n_nit_sot'
]
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
])
info
[
'n_nit_sot'
]
=
sum
([
nd
.
op
.
n_nit_sot
for
nd
in
nodes
])
info
[
'truncate_gradient'
]
=
nodes
[
0
]
.
op
.
truncate_gradient
info
[
'truncate_gradient'
]
=
nodes
[
0
]
.
op
.
truncate_gradient
info
[
'name'
]
=
'&'
.
join
([
nd
.
op
.
name
for
nd
in
nodes
])
info
[
'name'
]
=
'&'
.
join
([
nd
.
op
.
name
for
nd
in
nodes
])
info
[
'mode'
]
=
nodes
[
0
]
.
op
.
mode
info
[
'mode'
]
=
nodes
[
0
]
.
op
.
mode
info
[
'inplace'
]
=
False
info
[
'inplace'
]
=
False
info
[
'gpu'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
as_while
info
[
'as_while'
]
=
as_while
info
[
'profile'
]
=
nodes
[
0
]
.
op
.
profile
info
[
'profile'
]
=
nodes
[
0
]
.
op
.
profile
inner_ins
=
[]
inner_ins
=
[]
outer_ins
=
[]
outer_ins
=
[]
inner_outs
=
[]
inner_outs
=
[]
outer_outs
=
[]
outer_outs
=
[]
...
@@ -828,57 +840,56 @@ class ScanMerge(gof.Optimizer):
...
@@ -828,57 +840,56 @@ class ScanMerge(gof.Optimizer):
k
.
name
+=
str
(
suffix
)
k
.
name
+=
str
(
suffix
)
return
ls
return
ls
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Seq
# Seq
inner_ins
+=
rename
(
nd
.
op
.
inner_seqs
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_seqs
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
),
idx
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# MitMot
# MitMot
inner_ins
+=
rename
(
nd
.
op
.
inner_mitmot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_mitmot
(),
idx
)
inner_outs
+=
nd
.
op
.
inner_mitmot_outs
()
inner_outs
+=
nd
.
op
.
inner_mitmot_outs
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitmot_taps
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitmot_taps
()
info
[
'mit_mot_out_slices'
]
+=
nd
.
op
.
mitmot_out_taps
()
info
[
'mit_mot_out_slices'
]
+=
nd
.
op
.
mitmot_out_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# MitSot
# MitSot
inner_ins
+=
rename
(
nd
.
op
.
inner_mitsot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_mitsot
(),
idx
)
inner_outs
+=
nd
.
op
.
inner_mitsot_outs
()
inner_outs
+=
nd
.
op
.
inner_mitsot_outs
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitsot_taps
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitsot_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# SitSot
# SitSot
inner_ins
+=
rename
(
nd
.
op
.
inner_sitsot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_sitsot
(),
idx
)
info
[
'tap_array'
]
+=
[[
-
1
]
for
x
in
xrange
(
nd
.
op
.
n_sit_sot
)]
info
[
'tap_array'
]
+=
[[
-
1
]
for
x
in
xrange
(
nd
.
op
.
n_sit_sot
)]
inner_outs
+=
nd
.
op
.
inner_sitsot_outs
()
inner_outs
+=
nd
.
op
.
inner_sitsot_outs
()
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
# Shared
inner_ins
+=
rename
(
nd
.
op
.
inner_shared
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_shared
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_shared
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_shared
(
nd
),
idx
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# NitSot
# NitSot
inner_outs
+=
nd
.
op
.
inner_nitsot_outs
()
inner_outs
+=
nd
.
op
.
inner_nitsot_outs
()
outer_ins
+=
rename
(
nd
.
op
.
outer_nitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_nitsot
(
nd
),
idx
)
outer_outs
+=
nd
.
op
.
outer_nitsot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_nitsot_outs
(
nd
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
# Shared
outer_outs
+=
nd
.
op
.
outer_shared_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_shared_outs
(
nd
)
inner_outs
+=
nd
.
op
.
inner_shared_outs
()
inner_outs
+=
nd
.
op
.
inner_shared_outs
()
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Non Seqs
# Non Seqs
inner_ins
+=
rename
(
nd
.
op
.
inner_non_seqs
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_non_seqs
(),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_non_seqs
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_non_seqs
(
nd
),
idx
)
# Add back the number of steps
# Add back the number of steps
outer_ins
=
[
nodes
[
0
]
.
inputs
[
0
]]
+
outer_ins
outer_ins
=
[
nodes
[
0
]
.
inputs
[
0
]]
+
outer_ins
...
@@ -897,8 +908,6 @@ class ScanMerge(gof.Optimizer):
...
@@ -897,8 +908,6 @@ class ScanMerge(gof.Optimizer):
return
zip
(
outer_outs
,
new_outs
)
return
zip
(
outer_outs
,
new_outs
)
def
belongs_to_set
(
self
,
node
,
set_nodes
):
def
belongs_to_set
(
self
,
node
,
set_nodes
):
"""
"""
This function checks if node `node` belongs to `set_nodes`, in the
This function checks if node `node` belongs to `set_nodes`, in the
...
@@ -918,7 +927,6 @@ class ScanMerge(gof.Optimizer):
...
@@ -918,7 +927,6 @@ class ScanMerge(gof.Optimizer):
except
TypeError
:
except
TypeError
:
pass
pass
rep_nsteps
=
rep
.
inputs
[
0
]
rep_nsteps
=
rep
.
inputs
[
0
]
try
:
try
:
rep_nsteps
=
int
(
get_constant_value
(
rep_nsteps
))
rep_nsteps
=
int
(
get_constant_value
(
rep_nsteps
))
...
@@ -943,11 +951,9 @@ class ScanMerge(gof.Optimizer):
...
@@ -943,11 +951,9 @@ class ScanMerge(gof.Optimizer):
rep
.
op
.
inputs
)
rep
.
op
.
inputs
)
return
same_cond
and
(
nsteps
==
rep_nsteps
)
and
can_add
return
same_cond
and
(
nsteps
==
rep_nsteps
)
and
can_add
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
# Collect all scan nodes ordered according to toposort
# Collect all scan nodes ordered according to toposort
scan_nodes
=
[
nd
for
nd
in
env
.
toposort
()
scan_nodes
=
[
nd
for
nd
in
env
.
toposort
()
if
isinstance
(
nd
.
op
,
scan_op
.
Scan
)]
if
isinstance
(
nd
.
op
,
scan_op
.
Scan
)]
# All sets of possibly mergeable nodes
# All sets of possibly mergeable nodes
...
@@ -955,7 +961,7 @@ class ScanMerge(gof.Optimizer):
...
@@ -955,7 +961,7 @@ class ScanMerge(gof.Optimizer):
for
nd
in
scan_nodes
:
for
nd
in
scan_nodes
:
belongs_to_set_idx
=
-
1
belongs_to_set_idx
=
-
1
for
pos
,
subset
in
enumerate
(
all_sets
):
for
pos
,
subset
in
enumerate
(
all_sets
):
if
self
.
belongs_to_set
(
nd
,
subset
):
if
self
.
belongs_to_set
(
nd
,
subset
):
assert
belongs_to_set_idx
==
-
1
assert
belongs_to_set_idx
==
-
1
belongs_to_set_idx
=
pos
belongs_to_set_idx
=
pos
...
@@ -968,7 +974,7 @@ class ScanMerge(gof.Optimizer):
...
@@ -968,7 +974,7 @@ class ScanMerge(gof.Optimizer):
for
subset
in
all_sets
:
for
subset
in
all_sets
:
if
len
(
subset
)
>
1
:
if
len
(
subset
)
>
1
:
proposal
=
self
.
merge
(
subset
)
proposal
=
self
.
merge
(
subset
)
env
.
replace_all_validate
(
proposal
,
reason
=
'scan_merge'
)
env
.
replace_all_validate
(
proposal
,
reason
=
'scan_merge'
)
# after const merge but before stabilize so that we can have identity
# after const merge but before stabilize so that we can have identity
...
@@ -980,23 +986,27 @@ scan_seqopt.register('scanOp_merge',
...
@@ -980,23 +986,27 @@ scan_seqopt.register('scanOp_merge',
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
def
has_duplicates
(
l
):
def
has_duplicates
(
l
):
"""returns true if l has any duplicates (according to __eq__)."""
"""returns true if l has any duplicates (according to __eq__)."""
return
len
(
set
(
l
))
<
len
(
l
)
return
len
(
set
(
l
))
<
len
(
l
)
def
make_equiv
(
lo
,
li
):
def
make_equiv
(
lo
,
li
):
"""builds a dictionary of equivalences between inner inputs based on the equivalence of their corresponding outer inputs."""
"""builds a dictionary of equivalences between inner inputs based on
the equivalence of their corresponding outer inputs."""
seeno
=
{}
seeno
=
{}
left
=
[]
left
=
[]
right
=
[]
right
=
[]
for
o
,
i
in
zip
(
lo
,
li
):
for
o
,
i
in
zip
(
lo
,
li
):
if
o
in
seeno
:
if
o
in
seeno
:
left
+=
[
i
]
left
+=
[
i
]
right
+=
[
o
]
right
+=
[
o
]
else
:
else
:
seeno
[
o
]
=
i
seeno
[
o
]
=
i
return
left
,
right
return
left
,
right
@gof.local_optimizer
([
None
])
@gof.local_optimizer
([
None
])
def
scan_merge_inouts
(
node
):
def
scan_merge_inouts
(
node
):
if
not
isinstance
(
node
.
op
,
scan_op
.
Scan
):
if
not
isinstance
(
node
.
op
,
scan_op
.
Scan
):
...
@@ -1056,58 +1066,68 @@ def scan_merge_inouts(node):
...
@@ -1056,58 +1066,68 @@ def scan_merge_inouts(node):
na
=
a
na
=
a
# start again
# start again
left
=
[]
left
=
[]
right
=
[]
right
=
[]
if
has_duplicates
(
na
.
outer_in_shared
):
if
has_duplicates
(
na
.
outer_in_shared
):
_left
,
_right
=
make_equiv
(
na
.
outer_in_shared
,
na
.
inner_in_shared
)
_left
,
_right
=
make_equiv
(
na
.
outer_in_shared
,
na
.
inner_in_shared
)
left
+=
_left
left
+=
_left
right
+=
_right
right
+=
_right
if
has_duplicates
(
na
.
outer_in_sit_sot
):
if
has_duplicates
(
na
.
outer_in_sit_sot
):
_left
,
_right
=
make_equiv
(
na
.
outer_in_sit_sot
,
na
.
inner_in_sit_sot
)
_left
,
_right
=
make_equiv
(
na
.
outer_in_sit_sot
,
na
.
inner_in_sit_sot
)
left
+=
_left
left
+=
_left
right
+=
_right
right
+=
_right
if
has_duplicates
(
na
.
outer_in_mit_mot
):
if
has_duplicates
(
na
.
outer_in_mit_mot
):
seen
=
{}
seen
=
{}
for
omm
,
imm
,
_sl
in
zip
(
na
.
outer_in_mit_mot
,
na
.
inner_in_mit_mot
,
na
.
mit_mot_in_slices
):
for
omm
,
imm
,
_sl
in
zip
(
na
.
outer_in_mit_mot
,
na
.
inner_in_mit_mot
,
na
.
mit_mot_in_slices
):
sl
=
tuple
(
_sl
)
sl
=
tuple
(
_sl
)
if
(
omm
,
sl
)
in
seen
:
if
(
omm
,
sl
)
in
seen
:
simm
=
seen
[(
omm
,
sl
)]
simm
=
seen
[(
omm
,
sl
)]
left
+=
imm
left
+=
imm
right
+=
simm
right
+=
simm
else
:
else
:
seen
[(
omm
,
sl
)]
=
imm
seen
[(
omm
,
sl
)]
=
imm
if
has_duplicates
(
na
.
outer_in_mit_sot
):
if
has_duplicates
(
na
.
outer_in_mit_sot
):
seen
=
{}
seen
=
{}
for
oms
,
ims
,
_sl
in
zip
(
na
.
outer_in_mit_sot
,
na
.
inner_in_mit_sot
,
na
.
mit_sot_in_slices
):
for
oms
,
ims
,
_sl
in
zip
(
na
.
outer_in_mit_sot
,
na
.
inner_in_mit_sot
,
na
.
mit_sot_in_slices
):
sl
=
tuple
(
_sl
)
sl
=
tuple
(
_sl
)
if
(
oms
,
sl
)
in
seen
:
if
(
oms
,
sl
)
in
seen
:
sims
=
seen
[(
oms
,
sl
)]
sims
=
seen
[(
oms
,
sl
)]
left
+=
ims
left
+=
ims
right
+=
sims
right
+=
sims
else
:
else
:
seen
[(
oms
,
sl
)]
=
ims
seen
[(
oms
,
sl
)]
=
ims
def
map_out
(
i
,
o
,
seen
):
def
map_out
(
i
,
o
,
seen
):
for
si
,
so
in
seen
:
for
si
,
so
in
seen
:
if
equal_computations
([
i
],
[
si
],
left
,
right
):
if
equal_computations
([
i
],
[
si
],
left
,
right
):
return
so
return
so
seen
.
append
((
i
,
o
))
seen
.
append
((
i
,
o
))
return
o
return
o
seen
=
[]
seen
=
[]
na
.
outer_out_nit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
)]
na
.
outer_out_nit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_nit_sot
,
na
.
outer_out_nit_sot
)]
seen
=
[]
seen
=
[]
na
.
outer_out_sit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_sit_sot
,
na
.
outer_out_sit_sot
)]
na
.
outer_out_sit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_sit_sot
,
na
.
outer_out_sit_sot
)]
seen
=
[]
seen
=
[]
na
.
outer_out_mit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_mit_sot
,
na
.
outer_out_mit_sot
)]
na
.
outer_out_mit_sot
=
[
map_out
(
i
,
o
,
seen
)
for
i
,
o
in
zip
(
na
.
inner_out_mit_sot
,
na
.
outer_out_mit_sot
)]
seen
=
[]
seen
=
[]
new_outer_out_mit_mot
=
[]
new_outer_out_mit_mot
=
[]
for
imm
,
omm
,
osl
in
zip
(
na
.
inner_out_mit_mot
,
na
.
outer_out_mit_mot
,
na
.
mit_mot_out_slices
):
for
imm
,
omm
,
osl
in
zip
(
na
.
inner_out_mit_mot
,
na
.
outer_out_mit_mot
,
na
.
mit_mot_out_slices
):
for
simm
,
somm
,
sosl
in
seen
:
for
simm
,
somm
,
sosl
in
seen
:
if
osl
==
sosl
and
equal_computations
(
imm
,
simm
,
left
,
right
):
if
osl
==
sosl
and
equal_computations
(
imm
,
simm
,
left
,
right
):
new_outer_out_mit_mot
.
append
(
somm
)
new_outer_out_mit_mot
.
append
(
somm
)
...
@@ -1120,7 +1140,7 @@ def scan_merge_inouts(node):
...
@@ -1120,7 +1140,7 @@ def scan_merge_inouts(node):
return
na
.
outer_outputs
return
na
.
outer_outputs
scan_seqopt
.
register
(
'scanOp_merge_inouts'
,
scan_seqopt
.
register
(
'scanOp_merge_inouts'
,
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
opt
.
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
3
,
3
,
'fast_run'
,
'fast_run'
,
'scan'
)
'scan'
)
theano/scan_module/tests/test_scan.py
浏览文件 @
5a3a1d82
...
@@ -2260,6 +2260,40 @@ class T_Scan(unittest.TestCase):
...
@@ -2260,6 +2260,40 @@ class T_Scan(unittest.TestCase):
assert
numpy
.
allclose
(
vnh0
,
tnh0
,
atol
=
1e-6
)
assert
numpy
.
allclose
(
vnh0
,
tnh0
,
atol
=
1e-6
)
assert
numpy
.
allclose
(
vnW
,
tnW
,
atol
=
1e-6
)
assert
numpy
.
allclose
(
vnW
,
tnW
,
atol
=
1e-6
)
def
test_pushout_all
(
self
):
W1
=
tensor
.
matrix
(
'W1'
)
W2
=
tensor
.
matrix
(
'W2'
)
h0
=
tensor
.
vector
(
'h0'
)
def
lambda_fn
(
h
,
W1
,
W2
):
return
tensor
.
dot
(
h
,
W1
+
W2
)
o
,
_
=
theano
.
scan
(
lambda_fn
,
non_sequences
=
[
h0
,
W1
,
W2
],
n_steps
=
5
)
f
=
theano
.
function
([
h0
,
W1
,
W2
],
o
,
mode
=
mode_with_opt
)
scan_nodes
=
[
x
for
x
in
f
.
maker
.
env
.
toposort
()
if
isinstance
(
x
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)]
assert
len
(
scan_nodes
)
==
0
seed
=
utt
.
fetch_seed
()
rng
=
numpy
.
random
.
RandomState
(
seed
)
floatX
=
theano
.
config
.
floatX
v_h
=
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,)),
dtype
=
floatX
)
v_W1
=
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,
2
)),
dtype
=
floatX
)
v_W2
=
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,
2
)),
dtype
=
floatX
)
v_out
=
numpy
.
dot
(
v_h
,
v_W1
+
v_W2
)
sol
=
numpy
.
zeros
((
5
,
2
))
# This line is here to make sol have the same shape as the output of
# theano. Note that what we ask theano to do is to repeat the 2
# elements vector v_out 5 times
sol
[:,:]
=
v_out
assert
numpy
.
allclose
(
sol
,
f
(
v_h
,
v_W1
,
v_W2
))
def
test_pushout
(
self
):
def
test_pushout
(
self
):
W1
=
tensor
.
matrix
(
'W1'
)
W1
=
tensor
.
matrix
(
'W1'
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论