Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8b368cc8
提交
8b368cc8
authored
10月 07, 2016
作者:
Faruk Ahmed
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
flake8 for scan_op
上级
d2aef4d9
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
130 行增加
和
139 行删除
+130
-139
scan_op.py
theano/scan_module/scan_op.py
+130
-139
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
8b368cc8
...
@@ -125,7 +125,7 @@ class Scan(PureOp):
...
@@ -125,7 +125,7 @@ class Scan(PureOp):
outputs
,
outputs
,
info
,
info
,
typeConstructor
=
None
,
typeConstructor
=
None
,
):
):
if
'gpua'
not
in
info
:
if
'gpua'
not
in
info
:
info
[
'gpua'
]
=
False
info
[
'gpua'
]
=
False
# adding properties into self
# adding properties into self
...
@@ -346,8 +346,8 @@ class Scan(PureOp):
...
@@ -346,8 +346,8 @@ class Scan(PureOp):
len
(
self
.
inner_shared
(
self
.
inputs
))
+
len
(
self
.
inner_shared
(
self
.
inputs
))
+
len
(
self
.
inner_non_seqs
(
self
.
inputs
)))
len
(
self
.
inner_non_seqs
(
self
.
inputs
)))
assert
n_outer_ins
==
n_inner_ins
,
\
assert
n_outer_ins
==
n_inner_ins
,
\
(
"The number of inputs given to the inner function of scan"
(
"The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan."
)
" does not match the number of inputs given to scan."
)
new_inputs
=
[
inputs
[
0
]]
new_inputs
=
[
inputs
[
0
]]
# assert dtype is consistent
# assert dtype is consistent
err_msg1
=
(
'When compiling the inner function of scan (the '
err_msg1
=
(
'When compiling the inner function of scan (the '
...
@@ -372,7 +372,7 @@ class Scan(PureOp):
...
@@ -372,7 +372,7 @@ class Scan(PureOp):
'have the same dimensionality, you can increase the '
'have the same dimensionality, you can increase the '
'dimensionality of the varialbe in the initial state of scan '
'dimensionality of the varialbe in the initial state of scan '
'by using dimshuffle or shape_padleft. '
'by using dimshuffle or shape_padleft. '
)
)
err_msg2
=
(
'When compiling the inner function of scan the '
err_msg2
=
(
'When compiling the inner function of scan the '
'following error has been encountered: The '
'following error has been encountered: The '
'initial state (`outputs_info` in scan nomenclature) '
'initial state (`outputs_info` in scan nomenclature) '
...
@@ -399,7 +399,7 @@ class Scan(PureOp):
...
@@ -399,7 +399,7 @@ class Scan(PureOp):
'have the same dimensionality, you can increase the '
'have the same dimensionality, you can increase the '
'dimensionality of the variable in the initial state of scan '
'dimensionality of the variable in the initial state of scan '
'by using dimshuffle or shape_padleft. '
'by using dimshuffle or shape_padleft. '
)
)
def
format
(
var
,
as_var
):
def
format
(
var
,
as_var
):
"""
"""
...
@@ -440,9 +440,9 @@ class Scan(PureOp):
...
@@ -440,9 +440,9 @@ class Scan(PureOp):
inner_mitmot
=
self
.
inner_mitmot
(
self
.
inputs
)
inner_mitmot
=
self
.
inner_mitmot
(
self
.
inputs
)
inner_mitmot_outs
=
self
.
inner_mitmot_outs
(
self
.
outputs
)
inner_mitmot_outs
=
self
.
inner_mitmot_outs
(
self
.
outputs
)
for
idx
,
(
itaps
,
otaps
,
_outer_mitmot
)
in
enumerate
(
for
idx
,
(
itaps
,
otaps
,
_outer_mitmot
)
in
enumerate
(
zip
(
self
.
mitmot_taps
(),
zip
(
self
.
mitmot_taps
(),
self
.
mitmot_out_taps
(),
self
.
mitmot_out_taps
(),
self
.
outer_mitmot
(
inputs
))):
self
.
outer_mitmot
(
inputs
))):
outer_mitmot
=
format
(
_outer_mitmot
,
as_var
=
inner_mitmot
[
ipos
])
outer_mitmot
=
format
(
_outer_mitmot
,
as_var
=
inner_mitmot
[
ipos
])
new_inputs
.
append
(
outer_mitmot
)
new_inputs
.
append
(
outer_mitmot
)
for
k
in
xrange
(
len
(
itaps
)):
for
k
in
xrange
(
len
(
itaps
)):
...
@@ -450,15 +450,15 @@ class Scan(PureOp):
...
@@ -450,15 +450,15 @@ class Scan(PureOp):
outer_mitmot
.
type
.
dtype
or
outer_mitmot
.
type
.
dtype
or
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_mitmot
),
str
(
outer_mitmot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
outer_mitmot
.
type
.
dtype
,
outer_mitmot
.
type
.
ndim
,
outer_mitmot
.
type
.
ndim
,
str
(
inner_mitmot
[
ipos
+
k
]),
str
(
inner_mitmot
[
ipos
+
k
]),
inner_mitmot
[
ipos
+
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
,
k
]
.
type
.
dtype
,
inner_mitmot
[
ipos
+
k
]
.
type
.
ndim
))
inner_mitmot
[
ipos
+
k
]
.
type
.
ndim
))
ipos
+=
len
(
itaps
)
ipos
+=
len
(
itaps
)
for
k
in
xrange
(
len
(
otaps
)):
for
k
in
xrange
(
len
(
otaps
)):
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
...
@@ -491,14 +491,14 @@ class Scan(PureOp):
...
@@ -491,14 +491,14 @@ class Scan(PureOp):
outer_mitsot
.
type
.
dtype
or
outer_mitsot
.
type
.
dtype
or
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_mitsot
),
str
(
outer_mitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
ndim
,
outer_mitsot
.
type
.
ndim
,
str
(
inner_mitsots
[
ipos
+
k
]),
str
(
inner_mitsots
[
ipos
+
k
]),
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
ipos
+=
len
(
itaps
)
ipos
+=
len
(
itaps
)
if
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
:
if
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
:
raise
ValueError
(
err_msg2
%
raise
ValueError
(
err_msg2
%
...
@@ -523,14 +523,14 @@ class Scan(PureOp):
...
@@ -523,14 +523,14 @@ class Scan(PureOp):
new_inputs
.
append
(
outer_sitsot
)
new_inputs
.
append
(
outer_sitsot
)
if
(
inner_sitsot
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
if
(
inner_sitsot
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_sitsot
),
str
(
outer_sitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
outer_sitsot
.
type
.
ndim
,
str
(
inner_sitsot
),
str
(
inner_sitsot
),
inner_sitsot
.
type
.
dtype
,
inner_sitsot
.
type
.
dtype
,
inner_sitsot
.
type
.
ndim
))
inner_sitsot
.
type
.
ndim
))
if
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
:
if
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
:
raise
ValueError
(
err_msg2
%
raise
ValueError
(
err_msg2
%
(
str
(
outer_sitsot
),
(
str
(
outer_sitsot
),
...
@@ -570,14 +570,14 @@ class Scan(PureOp):
...
@@ -570,14 +570,14 @@ class Scan(PureOp):
(
outer_shared
.
dtype
!=
inner_shared
.
dtype
or
(
outer_shared
.
dtype
!=
inner_shared
.
dtype
or
outer_shared
.
ndim
!=
inner_shared
.
ndim
)):
outer_shared
.
ndim
!=
inner_shared
.
ndim
)):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_shared
),
str
(
outer_shared
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_shared
.
dtype
,
outer_shared
.
dtype
,
outer_shared
.
ndim
,
outer_shared
.
ndim
,
str
(
inner_shared
),
str
(
inner_shared
),
inner_shared
.
dtype
,
inner_shared
.
dtype
,
inner_shared
.
ndim
))
inner_shared
.
ndim
))
# We do not need to call `format` on outer_nisot arguments.
# We do not need to call `format` on outer_nisot arguments.
# outer_nitsot stands for no input tap single output tap. This means
# outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent
# these are states that do not feed anything back in the recurrent
...
@@ -595,7 +595,7 @@ class Scan(PureOp):
...
@@ -595,7 +595,7 @@ class Scan(PureOp):
if
inner_nonseq
.
type
!=
outer_nonseq
.
type
:
if
inner_nonseq
.
type
!=
outer_nonseq
.
type
:
raise
ValueError
((
'Argument
%
s given to scan node does not'
raise
ValueError
((
'Argument
%
s given to scan node does not'
' match its correspondance
%
s'
)
%
' match its correspondance
%
s'
)
%
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
# For every nit_sot input we get as input a int/uint that
# For every nit_sot input we get as input a int/uint that
...
@@ -788,7 +788,7 @@ class Scan(PureOp):
...
@@ -788,7 +788,7 @@ class Scan(PureOp):
# Wrap the corresponding input as usual. Leave the
# Wrap the corresponding input as usual. Leave the
# output as-is.
# output as-is.
wrapped_inputs
.
append
(
In
(
self
.
inputs
[
input_idx
],
wrapped_inputs
.
append
(
In
(
self
.
inputs
[
input_idx
],
borrow
=
False
))
borrow
=
False
))
input_idx
+=
1
input_idx
+=
1
# Wrap the inputs not associated to mitmots and wrap the remaining
# Wrap the inputs not associated to mitmots and wrap the remaining
...
@@ -841,7 +841,7 @@ class Scan(PureOp):
...
@@ -841,7 +841,7 @@ class Scan(PureOp):
profile
=
None
profile
=
None
if
(
theano
.
config
.
profile
or
if
(
theano
.
config
.
profile
or
(
isinstance
(
self
.
profile
,
(
string_types
,
bool
,
integer_types
))
(
isinstance
(
self
.
profile
,
(
string_types
,
bool
,
integer_types
))
and
self
.
profile
)):
and
self
.
profile
)):
if
isinstance
(
self
.
profile
,
string_types
):
if
isinstance
(
self
.
profile
,
string_types
):
profile
=
ScanProfileStats
(
name
=
self
.
profile
)
profile
=
ScanProfileStats
(
name
=
self
.
profile
)
else
:
else
:
...
@@ -866,7 +866,7 @@ class Scan(PureOp):
...
@@ -866,7 +866,7 @@ class Scan(PureOp):
for
out
in
self
.
fn
.
maker
.
fgraph
.
outputs
]
for
out
in
self
.
fn
.
maker
.
fgraph
.
outputs
]
try
:
try
:
if
impl
==
'py'
:
if
impl
==
'py'
:
raise
theano
.
gof
.
cmodule
.
MissingGXX
raise
theano
.
gof
.
cmodule
.
MissingGXX
cython_mintaps
=
numpy
.
asarray
(
self
.
mintaps
,
dtype
=
'int32'
)
cython_mintaps
=
numpy
.
asarray
(
self
.
mintaps
,
dtype
=
'int32'
)
cython_tap_array_len
=
\
cython_tap_array_len
=
\
...
@@ -890,16 +890,16 @@ class Scan(PureOp):
...
@@ -890,16 +890,16 @@ class Scan(PureOp):
d1
=
numpy
.
max
(
cython_mit_mot_out_nslices
)
d1
=
numpy
.
max
(
cython_mit_mot_out_nslices
)
d0
=
len
(
self
.
mit_mot_out_slices
)
d0
=
len
(
self
.
mit_mot_out_slices
)
cython_mit_mot_out_slices
=
numpy
.
zeros
((
d0
,
d1
),
cython_mit_mot_out_slices
=
numpy
.
zeros
((
d0
,
d1
),
dtype
=
'int32'
)
dtype
=
'int32'
)
for
_d0
in
xrange
(
d0
):
for
_d0
in
xrange
(
d0
):
for
_d1
in
xrange
(
cython_mit_mot_out_nslices
[
_d0
]):
for
_d1
in
xrange
(
cython_mit_mot_out_nslices
[
_d0
]):
cython_mit_mot_out_slices
[
_d0
,
_d1
]
=
\
cython_mit_mot_out_slices
[
_d0
,
_d1
]
=
\
self
.
mit_mot_out_slices
[
_d0
][
_d1
]
self
.
mit_mot_out_slices
[
_d0
][
_d1
]
cython_vector_seqs
=
numpy
.
asarray
(
self
.
vector_seqs
,
cython_vector_seqs
=
numpy
.
asarray
(
self
.
vector_seqs
,
dtype
=
'int32'
)
dtype
=
'int32'
)
cython_vector_outs
=
numpy
.
asarray
(
self
.
vector_outs
,
cython_vector_outs
=
numpy
.
asarray
(
self
.
vector_outs
,
dtype
=
'int32'
)
dtype
=
'int32'
)
cython_mitmots_preallocated
=
numpy
.
asarray
(
self
.
mitmots_preallocated
,
cython_mitmots_preallocated
=
numpy
.
asarray
(
self
.
mitmots_preallocated
,
dtype
=
'int32'
)
dtype
=
'int32'
)
...
@@ -910,39 +910,38 @@ class Scan(PureOp):
...
@@ -910,39 +910,38 @@ class Scan(PureOp):
if
hasattr
(
self
,
'destroy_map'
):
if
hasattr
(
self
,
'destroy_map'
):
cython_destroy_map
=
[
x
in
self
.
destroy_map
cython_destroy_map
=
[
x
in
self
.
destroy_map
for
x
in
xrange
(
len
(
node
.
outputs
))]
for
x
in
xrange
(
len
(
node
.
outputs
))]
else
:
else
:
cython_destroy_map
=
[
0
for
x
in
xrange
(
len
(
node
.
outputs
))]
cython_destroy_map
=
[
0
for
x
in
xrange
(
len
(
node
.
outputs
))]
cython_destroy_map
=
numpy
.
asarray
(
cython_destroy_map
,
cython_destroy_map
=
numpy
.
asarray
(
cython_destroy_map
,
dtype
=
'int32'
)
dtype
=
'int32'
)
from
.
import
scan_perform_ext
from
.
import
scan_perform_ext
p
=
lambda
node
,
args
,
outs
:
\
p
=
lambda
node
,
args
,
outs
:
\
scan_perform_ext
.
perform
(
scan_perform_ext
.
perform
(
self
.
n_shared_outs
,
self
.
n_shared_outs
,
self
.
n_mit_mot_outs
,
self
.
n_mit_mot_outs
,
self
.
n_seqs
,
self
.
n_seqs
,
self
.
n_mit_mot
,
self
.
n_mit_mot
,
self
.
n_mit_sot
,
self
.
n_mit_sot
,
self
.
n_sit_sot
,
self
.
n_sit_sot
,
self
.
n_nit_sot
,
self
.
n_nit_sot
,
args
[
0
],
args
[
0
],
self
.
as_while
,
self
.
as_while
,
cython_mintaps
,
cython_mintaps
,
cython_tap_array
,
cython_tap_array
,
cython_tap_array_len
,
cython_tap_array_len
,
cython_vector_seqs
,
cython_vector_seqs
,
cython_vector_outs
,
cython_vector_outs
,
cython_mit_mot_out_slices
,
cython_mit_mot_out_slices
,
cython_mit_mot_out_nslices
,
cython_mit_mot_out_nslices
,
cython_mitmots_preallocated
,
cython_mitmots_preallocated
,
cython_inps_is_tensor
,
cython_inps_is_tensor
,
cython_outs_is_tensor
,
cython_outs_is_tensor
,
self
.
fn
.
fn
,
self
.
fn
.
fn
,
self
.
fn
,
self
.
fn
,
cython_destroy_map
,
cython_destroy_map
,
args
,
args
,
outs
,
outs
,
self
,
node
)
self
,
node
)
except
(
ImportError
,
theano
.
gof
.
cmodule
.
MissingGXX
):
except
(
ImportError
,
theano
.
gof
.
cmodule
.
MissingGXX
):
p
=
self
.
execute
p
=
self
.
execute
# default arguments are stored in the closure of `rval`
# default arguments are stored in the closure of `rval`
...
@@ -1004,8 +1003,8 @@ class Scan(PureOp):
...
@@ -1004,8 +1003,8 @@ class Scan(PureOp):
def
inner_mitsot
(
self
,
list_inputs
):
def
inner_mitsot
(
self
,
list_inputs
):
n_mitmot_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
n_mitmot_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
return
list_inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
return
list_inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
self
.
n_seqs
+
ntaps_upto_sit_sot
]
self
.
n_seqs
+
ntaps_upto_sit_sot
]
...
@@ -1094,7 +1093,7 @@ class Scan(PureOp):
...
@@ -1094,7 +1093,7 @@ class Scan(PureOp):
if
isinstance
(
list_outputs
,
Apply
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
list_outputs
=
list_outputs
.
outputs
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
return
list_outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
list_outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_non_seqs
(
self
,
list_inputs
):
def
inner_non_seqs
(
self
,
list_inputs
):
...
@@ -1153,10 +1152,10 @@ class Scan(PureOp):
...
@@ -1153,10 +1152,10 @@ class Scan(PureOp):
for
idx
,
seq
in
enumerate
(
args
[
1
:
self
.
seqs_arg_offset
]):
for
idx
,
seq
in
enumerate
(
args
[
1
:
self
.
seqs_arg_offset
]):
if
seq
.
shape
[
0
]
<
n_steps
:
if
seq
.
shape
[
0
]
<
n_steps
:
raise
ValueError
((
'Sequence is shorter then the required '
raise
ValueError
((
'Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'number of steps : (n_steps, seq, '
'seq.shape):'
),
n_steps
,
'seq.shape):'
),
n_steps
,
node
.
inputs
[
1
+
idx
],
node
.
inputs
[
1
+
idx
],
seq
.
shape
)
seq
.
shape
)
seqs
.
append
(
seq
)
seqs
.
append
(
seq
)
# 2. Allocate memory for the outputs. Construct the list:
# 2. Allocate memory for the outputs. Construct the list:
...
@@ -1165,15 +1164,15 @@ class Scan(PureOp):
...
@@ -1165,15 +1164,15 @@ class Scan(PureOp):
# output
# output
store_steps
=
[
arg
.
shape
[
0
]
for
arg
store_steps
=
[
arg
.
shape
[
0
]
for
arg
in
args
[
self
.
seqs_arg_offset
:
in
args
[
self
.
seqs_arg_offset
:
self
.
shared_arg_offset
]]
self
.
shared_arg_offset
]]
store_steps
+=
[
arg
for
arg
in
store_steps
+=
[
arg
for
arg
in
args
[
self
.
nit_sot_arg_offset
:
args
[
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
]
self
.
nit_sot_arg_offset
+
self
.
n_nit_sot
]
]
]
pos
=
[(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
pos
=
[(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
)]
in
xrange
(
self
.
n_outs
+
self
.
n_nit_sot
)]
if
not
getattr
(
self
,
'destroy_map'
,
None
):
if
not
getattr
(
self
,
'destroy_map'
,
None
):
self
.
destroy_map
=
OrderedDict
()
self
.
destroy_map
=
OrderedDict
()
# 2.1 Create storage space for outputs
# 2.1 Create storage space for outputs
...
@@ -1207,7 +1206,7 @@ class Scan(PureOp):
...
@@ -1207,7 +1206,7 @@ class Scan(PureOp):
old_output_data
=
[
None
]
*
len
(
output_storage
)
old_output_data
=
[
None
]
*
len
(
output_storage
)
fn
=
self
.
fn
.
fn
fn
=
self
.
fn
.
fn
offset
=
(
self
.
n_seqs
+
sum
(
map
(
len
,
self
.
tap_array
[:
self
.
n_outs
]))
+
offset
=
(
self
.
n_seqs
+
sum
(
map
(
len
,
self
.
tap_array
[:
self
.
n_outs
]))
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
for
idx
in
xrange
(
len
(
other_args
)):
for
idx
in
xrange
(
len
(
other_args
)):
input_storage
[
idx
+
offset
]
.
storage
[
0
]
=
other_args
[
idx
]
input_storage
[
idx
+
offset
]
.
storage
[
0
]
=
other_args
[
idx
]
...
@@ -1221,7 +1220,7 @@ class Scan(PureOp):
...
@@ -1221,7 +1220,7 @@ class Scan(PureOp):
for
idx
in
xrange
(
self
.
n_seqs
):
for
idx
in
xrange
(
self
.
n_seqs
):
if
self
.
vector_seqs
[
idx
]:
if
self
.
vector_seqs
[
idx
]:
input_storage
[
idx
]
.
storage
[
0
]
=
\
input_storage
[
idx
]
.
storage
[
0
]
=
\
seqs
[
idx
][
i
:
i
+
1
]
.
reshape
(())
seqs
[
idx
][
i
:
i
+
1
]
.
reshape
(())
else
:
else
:
input_storage
[
idx
]
.
storage
[
0
]
=
seqs
[
idx
][
i
]
input_storage
[
idx
]
.
storage
[
0
]
=
seqs
[
idx
][
i
]
...
@@ -1231,7 +1230,7 @@ class Scan(PureOp):
...
@@ -1231,7 +1230,7 @@ class Scan(PureOp):
for
tap
in
self
.
tap_array
[
idx
]:
for
tap
in
self
.
tap_array
[
idx
]:
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
_idx
=
(
pos
[
idx
]
+
tap
)
%
store_steps
[
idx
]
input_storage
[
offset
]
.
storage
[
0
]
=
\
input_storage
[
offset
]
.
storage
[
0
]
=
\
outs
[
idx
][
0
][
_idx
:
_idx
+
1
]
.
reshape
(())
outs
[
idx
][
0
][
_idx
:
_idx
+
1
]
.
reshape
(())
offset
+=
1
offset
+=
1
else
:
else
:
for
tap
in
self
.
tap_array
[
idx
]:
for
tap
in
self
.
tap_array
[
idx
]:
...
@@ -1400,7 +1399,7 @@ class Scan(PureOp):
...
@@ -1400,7 +1399,7 @@ class Scan(PureOp):
# This output tap has not been preallocated, recover
# This output tap has not been preallocated, recover
# its value as usual
# its value as usual
outs
[
j
][
0
][
k
+
pos
[
j
]]
=
\
outs
[
j
][
0
][
k
+
pos
[
j
]]
=
\
output_storage
[
offset_out
]
.
storage
[
0
]
output_storage
[
offset_out
]
.
storage
[
0
]
offset_out
+=
1
offset_out
+=
1
mitmot_out_idx
+=
1
mitmot_out_idx
+=
1
...
@@ -1417,7 +1416,7 @@ class Scan(PureOp):
...
@@ -1417,7 +1416,7 @@ class Scan(PureOp):
# Copy the output value to `outs`, if necessary
# Copy the output value to `outs`, if necessary
if
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]:
if
store_steps
[
j
]
==
1
or
self
.
vector_outs
[
j
]:
outs
[
j
][
0
][
pos
[
j
]]
=
\
outs
[
j
][
0
][
pos
[
j
]]
=
\
output_storage
[
offset_out
+
j
]
.
storage
[
0
]
output_storage
[
offset_out
+
j
]
.
storage
[
0
]
else
:
else
:
# Check whether the initialization of the output storage
# Check whether the initialization of the output storage
# map for this output has been reused.
# map for this output has been reused.
...
@@ -1446,7 +1445,7 @@ class Scan(PureOp):
...
@@ -1446,7 +1445,7 @@ class Scan(PureOp):
if
i
==
0
:
if
i
==
0
:
jout
=
j
+
offset_out
jout
=
j
+
offset_out
shape
=
(
store_steps
[
j
],)
+
\
shape
=
(
store_steps
[
j
],)
+
\
output_storage
[
jout
]
.
storage
[
0
]
.
shape
output_storage
[
jout
]
.
storage
[
0
]
.
shape
if
len
(
output_storage
[
jout
]
.
storage
[
0
]
.
shape
)
==
0
:
if
len
(
output_storage
[
jout
]
.
storage
[
0
]
.
shape
)
==
0
:
self
.
vector_outs
[
j
]
=
True
self
.
vector_outs
[
j
]
=
True
dtype
=
output_storage
[
jout
]
.
storage
[
0
]
.
dtype
dtype
=
output_storage
[
jout
]
.
storage
[
0
]
.
dtype
...
@@ -1490,7 +1489,7 @@ class Scan(PureOp):
...
@@ -1490,7 +1489,7 @@ class Scan(PureOp):
outs
[
j
][
0
]
=
output_storage
[
jout
]
.
storage
[
0
]
outs
[
j
][
0
]
=
output_storage
[
jout
]
.
storage
[
0
]
pos
=
[(
idx
+
1
)
%
store
for
idx
,
store
in
pos
=
[(
idx
+
1
)
%
store
for
idx
,
store
in
izip
(
pos
,
store_steps
)]
izip
(
pos
,
store_steps
)]
i
=
i
+
1
i
=
i
+
1
# 6. Check if you need to re-order output buffers
# 6. Check if you need to re-order output buffers
...
@@ -1654,17 +1653,15 @@ class Scan(PureOp):
...
@@ -1654,17 +1653,15 @@ class Scan(PureOp):
self_outs
=
self
.
outputs
[:
-
1
]
self_outs
=
self
.
outputs
[:
-
1
]
else
:
else
:
self_outs
=
self
.
outputs
self_outs
=
self
.
outputs
outs_shape
=
scan_utils
.
infer_shape
(
outs_shape
=
scan_utils
.
infer_shape
(
outs
=
self_outs
,
outs
=
self_outs
,
inputs
=
self
.
inputs
,
inputs
=
self
.
inputs
,
input_shapes
=
inner_ins_shapes
)
input_shapes
=
inner_ins_shapes
)
# Will be used to check if outs_shape can be expressed without using
# Will be used to check if outs_shape can be expressed without using
# variables in self.inputs.
# variables in self.inputs.
# The shapes of node.inputs are valid.
# The shapes of node.inputs are valid.
validator
=
scan_utils
.
Validator
(
validator
=
scan_utils
.
Validator
(
valid
=
input_shapes
,
valid
=
input_shapes
,
invalid
=
self
.
inputs
,
invalid
=
self
.
inputs
,
valid_equivalent
=
out_equivalent
)
valid_equivalent
=
out_equivalent
)
offset
=
1
+
self
.
n_seqs
offset
=
1
+
self
.
n_seqs
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
...
@@ -1699,7 +1696,7 @@ class Scan(PureOp):
...
@@ -1699,7 +1696,7 @@ class Scan(PureOp):
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
+=
[
x
for
x
in
scan_outs
+=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
self
.
n_shared_outs
]]
input_shapes
[
offset
:
offset
+
self
.
n_shared_outs
]]
# if we are dealing with a repeat-until, then we do not know the
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
# leading dimension so we replace it for every entry with Shape_i
if
self
.
as_while
:
if
self
.
as_while
:
...
@@ -1763,7 +1760,7 @@ class Scan(PureOp):
...
@@ -1763,7 +1760,7 @@ class Scan(PureOp):
j_inp_idx
=
self
.
var_mappings
[
"outer_inp_from_outer_out"
][
jidx
]
j_inp_idx
=
self
.
var_mappings
[
"outer_inp_from_outer_out"
][
jidx
]
if
j_inp_idx
!=
-
1
:
if
j_inp_idx
!=
-
1
:
if
connection_pattern
[
j_inp_idx
][
iidx
]
==
True
:
if
connection_pattern
[
j_inp_idx
][
iidx
]
==
True
:
for
k
in
xrange
(
len
(
connection_pattern
)):
for
k
in
xrange
(
len
(
connection_pattern
)):
if
connection_pattern
[
k
][
jidx
]:
if
connection_pattern
[
k
][
jidx
]:
connection_pattern
[
k
][
iidx
]
=
True
connection_pattern
[
k
][
iidx
]
=
True
...
@@ -1887,18 +1884,18 @@ class Scan(PureOp):
...
@@ -1887,18 +1884,18 @@ class Scan(PureOp):
# With the global mapping inferred, the individual mappings
# With the global mapping inferred, the individual mappings
# can be produced
# can be produced
mappings
=
{
"outer_inp_from_outer_out"
:
{},
mappings
=
{
"outer_inp_from_outer_out"
:
{},
"inner_inp_from_outer_out"
:
{},
"inner_inp_from_outer_out"
:
{},
"inner_out_from_outer_out"
:
{},
"inner_out_from_outer_out"
:
{},
"inner_inp_from_outer_inp"
:
{},
"inner_inp_from_outer_inp"
:
{},
"inner_out_from_outer_inp"
:
{},
"inner_out_from_outer_inp"
:
{},
"outer_out_from_outer_inp"
:
{},
"outer_out_from_outer_inp"
:
{},
"outer_inp_from_inner_inp"
:
{},
"outer_inp_from_inner_inp"
:
{},
"inner_out_from_inner_inp"
:
{},
"inner_out_from_inner_inp"
:
{},
"outer_out_from_inner_inp"
:
{},
"outer_out_from_inner_inp"
:
{},
"outer_inp_from_inner_out"
:
{},
"outer_inp_from_inner_out"
:
{},
"inner_inp_from_inner_out"
:
{},
"inner_inp_from_inner_out"
:
{},
"outer_out_from_inner_out"
:
{}}
"outer_out_from_inner_out"
:
{}}
for
(
oinp
,
iinp
,
iout
,
oout
)
in
izip
(
outer_input_indices
,
for
(
oinp
,
iinp
,
iout
,
oout
)
in
izip
(
outer_input_indices
,
inner_input_indices
,
inner_input_indices
,
...
@@ -1944,7 +1941,7 @@ class Scan(PureOp):
...
@@ -1944,7 +1941,7 @@ class Scan(PureOp):
grad_steps
=
self
.
outer_sitsot_outs
(
outs
)[
0
]
.
shape
[
0
]
-
1
grad_steps
=
self
.
outer_sitsot_outs
(
outs
)[
0
]
.
shape
[
0
]
-
1
elif
self
.
n_mit_sot
>
0
:
elif
self
.
n_mit_sot
>
0
:
grad_steps
=
self
.
outer_mitsot_outs
(
outs
)[
0
]
.
shape
[
0
]
+
\
grad_steps
=
self
.
outer_mitsot_outs
(
outs
)[
0
]
.
shape
[
0
]
+
\
self
.
mintaps
[
self
.
n_mit_mot
]
self
.
mintaps
[
self
.
n_mit_mot
]
else
:
else
:
grad_steps
=
inputs
[
0
]
grad_steps
=
inputs
[
0
]
...
@@ -2031,14 +2028,13 @@ class Scan(PureOp):
...
@@ -2031,14 +2028,13 @@ class Scan(PureOp):
# to X.
# to X.
known_grads
=
OrderedDict
([(
k
.
copy
(),
v
)
for
(
k
,
v
)
in
known_grads
.
items
()])
known_grads
=
OrderedDict
([(
k
.
copy
(),
v
)
for
(
k
,
v
)
in
known_grads
.
items
()])
grads
=
gradient
.
grad
(
grads
=
gradient
.
grad
(
cost
=
None
,
cost
=
None
,
known_grads
=
known_grads
,
known_grads
=
known_grads
,
wrt
=
wrt
,
wrt
=
wrt
,
consider_constant
=
wrt
,
consider_constant
=
wrt
,
disconnected_inputs
=
'ignore'
,
disconnected_inputs
=
'ignore'
,
return_disconnected
=
'None'
,
return_disconnected
=
'None'
,
null_gradients
=
'return'
)
null_gradients
=
'return'
)
for
i
in
range
(
len
(
wrt
)):
for
i
in
range
(
len
(
wrt
)):
gmp
[
wrt
[
i
]]
=
grads
[
i
]
gmp
[
wrt
[
i
]]
=
grads
[
i
]
...
@@ -2098,7 +2094,6 @@ class Scan(PureOp):
...
@@ -2098,7 +2094,6 @@ class Scan(PureOp):
dC_dXt
=
safe_new
(
dC_douts
[
idx
][
0
])
dC_dXt
=
safe_new
(
dC_douts
[
idx
][
0
])
dC_dXts
.
append
(
dC_dXt
)
dC_dXts
.
append
(
dC_dXt
)
known_grads
=
OrderedDict
()
known_grads
=
OrderedDict
()
dc_dxts_idx
=
0
dc_dxts_idx
=
0
for
i
in
range
(
len
(
diff_outputs
)):
for
i
in
range
(
len
(
diff_outputs
)):
...
@@ -2153,7 +2148,7 @@ class Scan(PureOp):
...
@@ -2153,7 +2148,7 @@ class Scan(PureOp):
dC_dXtm1s
.
append
(
safe_new
(
dC_dXts
[
opos
]))
dC_dXtm1s
.
append
(
safe_new
(
dC_dXts
[
opos
]))
if
hasattr
(
x
,
'dtype'
)
and
x
.
dtype
!=
dC_dXts
[
opos
]
.
dtype
:
if
hasattr
(
x
,
'dtype'
)
and
x
.
dtype
!=
dC_dXts
[
opos
]
.
dtype
:
dC_dinps_t
[
pos
+
self
.
n_seqs
]
=
\
dC_dinps_t
[
pos
+
self
.
n_seqs
]
=
\
x
.
astype
(
dC_dXts
[
opos
]
.
dtype
)
x
.
astype
(
dC_dXts
[
opos
]
.
dtype
)
else
:
else
:
dC_dXtm1s
.
append
(
safe_new
(
x
))
dC_dXtm1s
.
append
(
safe_new
(
x
))
...
@@ -2180,7 +2175,7 @@ class Scan(PureOp):
...
@@ -2180,7 +2175,7 @@ class Scan(PureOp):
seq
=
outs
[
idx
]
seq
=
outs
[
idx
]
for
k
in
self
.
tap_array
[
idx
]:
for
k
in
self
.
tap_array
[
idx
]:
if
outmaxtap
-
k
!=
0
:
if
outmaxtap
-
k
!=
0
:
nw_seq
=
seq
[
k
-
mintap
:
-
(
outmaxtap
-
k
)][::
-
1
]
nw_seq
=
seq
[
k
-
mintap
:
-
(
outmaxtap
-
k
)][::
-
1
]
else
:
else
:
nw_seq
=
seq
[
k
-
mintap
:][::
-
1
]
nw_seq
=
seq
[
k
-
mintap
:][::
-
1
]
outer_inp_seqs
.
append
(
nw_seq
)
outer_inp_seqs
.
append
(
nw_seq
)
...
@@ -2288,7 +2283,6 @@ class Scan(PureOp):
...
@@ -2288,7 +2283,6 @@ class Scan(PureOp):
new_inner_out_mitmot
=
theano
.
clone
(
new_inner_out_mitmot
,
new_inner_out_mitmot
=
theano
.
clone
(
new_inner_out_mitmot
,
replace
=
[(
to_replace
,
replacement
)])
replace
=
[(
to_replace
,
replacement
)])
inner_out_mitmot
.
append
(
new_inner_out_mitmot
)
inner_out_mitmot
.
append
(
new_inner_out_mitmot
)
if
not
disconnected_dC_dinps_t
[
ins_pos
]:
if
not
disconnected_dC_dinps_t
[
ins_pos
]:
...
@@ -2553,8 +2547,7 @@ class Scan(PureOp):
...
@@ -2553,8 +2547,7 @@ class Scan(PureOp):
gradients
.
append
(
NullType
(
t
)())
gradients
.
append
(
NullType
(
t
)())
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
for
p
,
(
x
,
t
)
in
enumerate
(
for
p
,
(
x
,
t
)
in
enumerate
(
zip
(
outputs
[:
end
],
type_outs
[:
end
])):
zip
(
outputs
[:
end
],
type_outs
[:
end
])):
if
t
==
'connected'
:
if
t
==
'connected'
:
gradients
.
append
(
x
[::
-
1
])
gradients
.
append
(
x
[::
-
1
])
elif
t
==
'disconnected'
:
elif
t
==
'disconnected'
:
...
@@ -2587,12 +2580,11 @@ class Scan(PureOp):
...
@@ -2587,12 +2580,11 @@ class Scan(PureOp):
start
=
len
(
gradients
)
start
=
len
(
gradients
)
gradients
+=
[
DisconnectedType
()()
gradients
+=
[
DisconnectedType
()()
for
x
in
xrange
(
self
.
n_nit_sot
)]
for
x
in
xrange
(
self
.
n_nit_sot
)]
begin
=
end
begin
=
end
end
=
begin
+
n_sitsot_outs
end
=
begin
+
n_sitsot_outs
for
p
,
(
x
,
t
)
in
enumerate
(
for
p
,
(
x
,
t
)
in
enumerate
(
zip
(
outputs
[
begin
:
end
],
type_outs
[
begin
:
end
])):
zip
(
outputs
[
begin
:
end
],
type_outs
[
begin
:
end
])):
if
t
==
'connected'
:
if
t
==
'connected'
:
gradients
.
append
(
x
[
-
1
])
gradients
.
append
(
x
[
-
1
])
elif
t
==
'disconnected'
:
elif
t
==
'disconnected'
:
...
@@ -2629,7 +2621,7 @@ class Scan(PureOp):
...
@@ -2629,7 +2621,7 @@ class Scan(PureOp):
self
.
outputs
,
'_rop'
)
self
.
outputs
,
'_rop'
)
self_inputs
=
rval
[
0
]
self_inputs
=
rval
[
0
]
rop_of_inputs
=
rval
[
0
][:
self
.
n_seqs
+
self
.
n_outs
]
+
\
rop_of_inputs
=
rval
[
0
][:
self
.
n_seqs
+
self
.
n_outs
]
+
\
rval
[
0
][
self
.
n_seqs
+
self
.
n_outs
+
self
.
n_shared_outs
:]
rval
[
0
][
self
.
n_seqs
+
self
.
n_outs
+
self
.
n_shared_outs
:]
self_outputs
=
rval
[
1
]
self_outputs
=
rval
[
1
]
# Step 1. Compute the R_op of the inner function
# Step 1. Compute the R_op of the inner function
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
...
@@ -2640,8 +2632,7 @@ class Scan(PureOp):
...
@@ -2640,8 +2632,7 @@ class Scan(PureOp):
rop_self_outputs
=
self_outputs
rop_self_outputs
=
self_outputs
if
self
.
info
[
'n_shared_outs'
]
>
0
:
if
self
.
info
[
'n_shared_outs'
]
>
0
:
rop_self_outputs
=
rop_self_outputs
[:
-
self
.
info
[
'n_shared_outs'
]]
rop_self_outputs
=
rop_self_outputs
[:
-
self
.
info
[
'n_shared_outs'
]]
rop_outs
=
tensor
.
Rop
(
rop_self_outputs
,
rop_of_inputs
,
rop_outs
=
tensor
.
Rop
(
rop_self_outputs
,
rop_of_inputs
,
inner_eval_points
)
inner_eval_points
)
if
type
(
rop_outs
)
not
in
(
list
,
tuple
):
if
type
(
rop_outs
)
not
in
(
list
,
tuple
):
rop_outs
=
[
rop_outs
]
rop_outs
=
[
rop_outs
]
# Step 2. Figure out what corresponds to what in the scan
# Step 2. Figure out what corresponds to what in the scan
...
@@ -2721,8 +2712,8 @@ class Scan(PureOp):
...
@@ -2721,8 +2712,8 @@ class Scan(PureOp):
e
=
e
+
self
.
n_mit_sot
e
=
e
+
self
.
n_mit_sot
ib
=
ie
ib
=
ie
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
self
.
tap_array
[
self
.
n_mit_mot
:
\
self
.
tap_array
[
self
.
n_mit_mot
:
\
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
clean_eval_points
=
[]
clean_eval_points
=
[]
for
inp
,
evp
in
zip
(
inputs
[
b
:
e
],
eval_points
[
b
:
e
]):
for
inp
,
evp
in
zip
(
inputs
[
b
:
e
],
eval_points
[
b
:
e
]):
if
evp
is
not
None
:
if
evp
is
not
None
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论