Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9d9d9020
提交
9d9d9020
authored
12月 14, 2011
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #298 from pascanur/scan_check
Scan check
上级
5c0887dd
da5ea343
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
238 行增加
和
147 行删除
+238
-147
basic.py
theano/scalar/basic.py
+4
-0
scan_op.py
theano/scan_module/scan_op.py
+203
-114
scan_opt.py
theano/scan_module/scan_opt.py
+31
-33
没有找到文件。
theano/scalar/basic.py
浏览文件 @
9d9d9020
...
@@ -119,6 +119,7 @@ class Scalar(Type):
...
@@ -119,6 +119,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType
TODO: refactor to be named ScalarType for consistency with TensorType
"""
"""
ndim
=
0
def
__init__
(
self
,
dtype
):
def
__init__
(
self
,
dtype
):
if
dtype
==
'floatX'
:
if
dtype
==
'floatX'
:
...
@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types
...
@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types
class
_scalar_py_operators
:
class
_scalar_py_operators
:
# So that we can simplify checking code when we have a mixture of Scalar
# variables and Tensor variables
ndim
=
0
#UNARY
#UNARY
def
__abs__
(
self
):
return
abs_
(
self
)
def
__abs__
(
self
):
return
abs_
(
self
)
...
...
theano/scan_module/scan_op.py
浏览文件 @
9d9d9020
...
@@ -149,11 +149,36 @@ class Scan(PureOp):
...
@@ -149,11 +149,36 @@ class Scan(PureOp):
self
.
_hash_inner_graph
=
self
.
info
[
'gpu_hash'
]
self
.
_hash_inner_graph
=
self
.
info
[
'gpu_hash'
]
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
"""
Conventions:
inner_? - the variable corresponding to ? in the inner function
of scan (the lambda function executed at every time
step)
outer_? - the variable corresponding to ? in the outer graph,
i.e. the main graph (where the scan op lives)
inner_?_out - the variable representing the new value of ? after
executing one step of scan (i.e. outputs given by
the inner function)
"""
assert
numpy
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
assert
numpy
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan
n_outer_ins
=
len
(
inputs
)
-
len
(
self
.
outer_nitsot
(
inputs
))
-
1
n_inner_ins
=
(
len
(
self
.
inner_seqs
(
self
.
inputs
))
+
len
(
self
.
mitmot_taps
())
+
len
(
self
.
mitsot_taps
())
+
len
(
self
.
inner_sitsot
(
self
.
inputs
))
+
len
(
self
.
inner_shared
(
self
.
inputs
))
+
len
(
self
.
inner_non_seqs
(
self
.
inputs
)))
assert
n_outer_ins
==
n_inner_ins
,
\
(
"The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan."
)
# assert dtype is consistent
# assert dtype is consistent
err_msg1
=
(
'When compiling the inner function of scan the '
err_msg1
=
(
'When compiling the inner function of scan the '
'following error has been encountered: The '
'following error has been encountered: The '
'
%
s
%
s
( the entry
number
%
d) has dtype '
'
%
s
%
s
(argument
number
%
d) has dtype '
'
%
s. The corresponding slice
%
s however has'
'
%
s. The corresponding slice
%
s however has'
' dtype
%
s. This should never happen, please '
' dtype
%
s. This should never happen, please '
'report to theano-dev mailing list'
'report to theano-dev mailing list'
...
@@ -161,7 +186,7 @@ class Scan(PureOp):
...
@@ -161,7 +186,7 @@ class Scan(PureOp):
err_msg2
=
(
'When compiling the inner function of scan the '
err_msg2
=
(
'When compiling the inner function of scan the '
'following error has been encountered: The '
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)'
'initial state (outputs_info in scan nomenclature)'
'of variable
%
s (
the entry
number
%
d)'
'of variable
%
s (
argument
number
%
d)'
' has dtype
%
s, while the result of the'
' has dtype
%
s, while the result of the'
' inner function for this output has dtype
%
s. This '
' inner function for this output has dtype
%
s. This '
'could happen if the inner graph of scan results in '
'could happen if the inner graph of scan results in '
...
@@ -185,85 +210,145 @@ class Scan(PureOp):
...
@@ -185,85 +210,145 @@ class Scan(PureOp):
# Check if input sequences and variables representing a slice of
# Check if input sequences and variables representing a slice of
# them have the same dtype
# them have the same dtype
for
idx
in
xrange
(
self
.
n_seqs
):
argoffset
=
0
if
inputs
[
1
+
idx
]
.
dtype
!=
self
.
inputs
[
idx
]
.
dtype
:
for
idx
,
(
inner_seq
,
outer_seq
)
in
enumerate
(
zip
(
self
.
inner_seqs
(
self
.
inputs
),
self
.
outer_seqs
(
inputs
))):
if
inner_seq
.
type
.
dtype
!=
outer_seq
[
idx
]
.
type
.
dtype
:
raise
ValueError
(
err_msg1
%
(
'sequence'
,
raise
ValueError
(
err_msg1
%
(
'sequence'
,
str
(
inputs
[
1
+
idx
]
),
str
(
outer_seq
),
idx
,
idx
,
inputs
[
1
+
idx
]
.
dtype
,
outer_seq
.
type
.
dtype
,
str
(
self
.
inputs
[
idx
]
),
str
(
inner_seq
),
self
.
inputs
[
idx
]
.
dtype
))
inner_seq
.
type
.
dtype
))
argoffset
+=
len
(
self
.
outer_seqs
(
inputs
))
# Check that this 3 things have the same dtype for mit_mot:
# Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output
# - initial state of the output
# - variable representing an input slice of the otuput
# - variable representing an input slice of the otuput
# - variable representing an output slice of the otuput
# - variable representing an output slice of the otuput
# Maybe checking that ndim fits would be good as well !?
ipos
=
0
index_i
=
self
.
n_seqs
opos
=
0
index_o
=
0
inner_mitmot
=
self
.
inner_mitmot
(
self
.
inputs
)
index
=
1
+
self
.
n_seqs
inner_mitmot_outs
=
self
.
inner_mitmot_outs
(
self
.
outputs
)
start
=
index
for
idx
,
(
itaps
,
otaps
,
outer_mitmot
)
in
enumerate
(
end
=
index
+
self
.
n_mit_mot
zip
(
self
.
mitmot_taps
(),
while
index
<
end
:
self
.
mitmot_out_taps
(),
for
k
in
self
.
tap_array
[
index
-
start
]:
self
.
outer_mitmot
(
inputs
))):
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
outer_mitmot
.
type
.
dtype
or
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
inputs
[
index
]),
str
(
outer_mitmot
),
index
,
argoffset
+
idx
,
inputs
[
index
]
.
dtype
,
outer_mitmot
.
type
.
dtype
,
str
(
self
.
inputs
[
index_i
]),
str
(
inner_mitmot
[
ipos
+
k
]),
self
.
inputs
[
index_i
]
.
dtype
))
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
index_i
+=
1
ipos
+=
len
(
itaps
)
for
k
in
self
.
mit_mot_out_slices
[
index
-
start
]:
for
k
in
xrange
(
len
(
otaps
)):
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
outer_mitmot
.
type
.
dtype
or
index
,
inner_mitmot_outs
[
opos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inputs
[
index
]
.
dtype
,
raise
ValueError
(
err_msg2
%
self
.
outputs
[
index_o
]
.
dtype
))
(
str
(
outer_mitmot
,
index_o
+=
1
argoffset
+
idx
,
index
+=
1
outer_mitmot
.
type
.
dtype
,
# Same checks as above but for outputs of type mit_sot and sit_sot
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
)))
end
+=
self
.
n_mit_sot
+
self
.
n_sit_sot
opos
+=
len
(
otaps
)
while
index
<
end
:
argoffset
+=
len
(
self
.
outer_mitmot
(
inputs
))
for
k
in
self
.
tap_array
[
index
-
start
]:
# Same checks as above but for outputs of type mit_sot
if
inputs
[
index
]
.
dtype
!=
self
.
inputs
[
index_i
]
.
dtype
:
ipos
=
0
raise
ValueError
(
err_msg1
%
(
'Initial state'
,
inner_mitsots
=
self
.
inner_mitsot
(
self
.
inputs
)
str
(
inputs
[
index
]),
for
idx
,
(
itaps
,
outer_mitsot
,
inner_mitsot_out
)
in
enumerate
(
index
,
zip
(
self
.
mitsot_taps
(),
inputs
[
index
]
.
dtype
,
self
.
outer_mitsot
(
inputs
),
str
(
self
.
inputs
[
index_i
]),
self
.
inner_mitsot_outs
(
self
.
outputs
))):
self
.
inputs
[
index_i
]
.
dtype
))
for
k
in
xrange
(
len
(
itaps
)):
index_i
+=
1
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
if
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
:
outer_mitsot
.
type
.
dtype
or
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
index
,
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
inputs
[
index
]
.
dtype
,
' in scan nomenclature) '
,
self
.
outputs
[
index_o
]
.
dtype
))
str
(
outer_mitsot
),
index_o
+=
1
argoffset
+
idx
,
index
+=
1
outer_mitsot
.
type
.
dtype
,
str
(
inner_mitsot
[
ipos
+
k
]),
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
))
ipos
+=
len
(
itaps
)
if
(
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
or
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitsot
,
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
inner_mitsot_out
.
type
.
dtype
)))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
# Same checks as above but for outputs of type sit_sot
for
idx
,
(
inner_sitsot
,
outer_sitsot
,
inner_sitsot_out
)
in
enumerate
(
zip
(
self
.
inner_sitsot
(
self
.
inputs
),
self
.
outer_sitsot
(
inputs
),
self
.
inner_sitsot_outs
(
self
.
outputs
))):
if
(
inner_sitsot
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
inner_sitsot
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
str
(
outer_sitsot
),
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
str
(
inner_sitsot
),
inner_sitsot
.
type
.
dtype
))
if
(
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
inner_sitsot_out
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_sitsot
,
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
inner_sitsot_out
.
type
.
dtype
)))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
# Check that the shared variable and their update rule have the same
# Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?!
# dtype. Maybe even same type ?!
end
+=
self
.
n_shared_outs
for
idx
,
(
inner_shared
,
inner_shared_out
,
outer_shared
)
in
enumerate
(
index_o
+=
self
.
n_nit_sot
zip
(
self
.
inner_shared
(
self
.
inputs
),
while
index
<
end
:
self
.
inner_shared_outs
(
self
.
outputs
),
if
(
hasattr
(
inputs
[
index
],
'dtype'
)
and
self
.
outer_shared
(
inputs
))):
inputs
[
index
]
.
dtype
!=
self
.
outputs
[
index_o
]
.
dtype
):
if
(
hasattr
(
outer_shared
,
'dtype'
)
and
raise
ValueError
(
err_msg2
%
(
str
(
inputs
[
index
]),
(
outer_shared
.
dtype
!=
inner_shared_out
.
dtype
or
index
,
outer_shared
.
ndim
!=
inner_shared_out
.
ndim
)):
inputs
[
index
]
.
dtype
,
raise
ValueError
(
err_msg2
%
(
str
(
outer_shared
),
self
.
outputs
[
index_o
]
.
dtype
))
idx
+
argoffset
,
index
+=
1
outer_shared
.
dtype
,
index_o
+=
1
inner_shared_out
.
dtype
))
for
x
in
inputs
[
index
:
index
+
self
.
n_nit_sot
]:
if
(
hasattr
(
outer_shared
,
'dtype'
)
and
(
outer_shared
.
dtype
!=
inner_shared
.
dtype
or
outer_shared
.
ndim
!=
inner_shared
.
ndim
)):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
str
(
outer_shared
),
argoffset
+
idx
,
outer_shared
.
dtype
,
str
(
inner_shared
),
inner_shared
.
dtype
))
for
inner_nonseq
,
outer_nonseq
in
zip
(
self
.
inner_non_seqs
(
self
.
inputs
),
self
.
outer_non_seqs
(
inputs
)):
if
(
inner_nonseq
.
type
.
dtype
!=
outer_nonseq
.
type
.
dtype
or
inner_nonseq
.
type
.
ndim
!=
outer_nonseq
.
type
.
ndim
):
raise
ValueError
((
'Argument
%
s given to scan node does not'
' match its correspondance
%
s'
)
%
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
# For every nit_sot input we get as input a int/uint that
# For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
# used by truncated BPTT and by scan space optimization
if
(
str
(
x
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
or
if
(
str
(
outer_nitsot
.
type
.
dtype
)[:
3
]
not
in
(
'uin'
,
'int'
)
or
x
.
ndim
!=
0
):
outer_nitsot
.
ndim
!=
0
):
raise
ValueError
(
'For output
%
d
you need to provide a '
raise
ValueError
(
'For output
%
s
you need to provide a '
'scalar int !'
,
x
)
'scalar int !'
,
str
(
outer_nitsot
)
)
apply_node
=
Apply
(
self
,
apply_node
=
Apply
(
self
,
inputs
,
inputs
,
...
@@ -459,25 +544,29 @@ class Scan(PureOp):
...
@@ -459,25 +544,29 @@ class Scan(PureOp):
rval
.
lazy
=
False
rval
.
lazy
=
False
return
rval
return
rval
def
inner_seqs
(
self
):
def
inner_seqs
(
self
,
list_inputs
):
return
self
.
inputs
[:
self
.
n_seqs
]
# Given the list of inner inputs this function grabs those
# corresponding to sequences
return
list_inputs
[:
self
.
n_seqs
]
def
outer_seqs
(
self
,
node
):
def
outer_seqs
(
self
,
list_inputs
):
return
node
.
inputs
[
1
:
1
+
self
.
n_seqs
]
# Given the list of outter inputs this function grabs those
# corresponding to sequences
return
list_inputs
[
1
:
1
+
self
.
n_seqs
]
def
inner_mitmot
(
self
):
def
inner_mitmot
(
self
,
list_inputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
return
self
.
inputs
[
self
.
n_seqs
:
self
.
n_seqs
+
n_taps
]
return
list_
inputs
[
self
.
n_seqs
:
self
.
n_seqs
+
n_taps
]
def
outer_mitmot
(
self
,
node
):
def
outer_mitmot
(
self
,
list_inputs
):
return
node
.
inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
return
list_
inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
def
inner_mitmot_outs
(
self
):
def
inner_mitmot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
self
.
outputs
[:
n_taps
]
return
list_
outputs
[:
n_taps
]
def
outer_mitmot_outs
(
self
,
node
):
def
outer_mitmot_outs
(
self
,
list_outputs
):
return
node
.
outputs
[:
self
.
n_mit_mot
]
return
list_
outputs
[:
self
.
n_mit_mot
]
def
mitmot_taps
(
self
):
def
mitmot_taps
(
self
):
return
self
.
tap_array
[:
self
.
n_mit_mot
]
return
self
.
tap_array
[:
self
.
n_mit_mot
]
...
@@ -485,98 +574,98 @@ class Scan(PureOp):
...
@@ -485,98 +574,98 @@ class Scan(PureOp):
def
mitmot_out_taps
(
self
):
def
mitmot_out_taps
(
self
):
return
self
.
mit_mot_out_slices
[:
self
.
n_mit_mot
]
return
self
.
mit_mot_out_slices
[:
self
.
n_mit_mot
]
def
inner_mitsot
(
self
):
def
inner_mitsot
(
self
,
list_inputs
):
n_mitmot_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
n_mitmot_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
return
self
.
inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
return
list_
inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
self
.
n_seqs
+
ntaps_upto_sit_sot
]
self
.
n_seqs
+
ntaps_upto_sit_sot
]
def
outer_mitsot
(
self
,
node
):
def
outer_mitsot
(
self
,
list_inputs
):
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
return
list_
inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
def
inner_mitsot_outs
(
self
):
def
inner_mitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
self
.
outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
return
list_
outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
def
outer_mitsot_outs
(
self
,
node
):
def
outer_mitsot_outs
(
self
,
list_outputs
):
return
node
.
outputs
[
self
.
n_mit_mot
:
return
list_
outputs
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
mitsot_taps
(
self
):
def
mitsot_taps
(
self
):
return
self
.
tap_array
[
self
.
n_mit_mot
:
return
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
inner_sitsot
(
self
):
def
inner_sitsot
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
return
self
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
list_
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot
(
self
,
node
):
def
outer_sitsot
(
self
,
list_inputs
):
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
return
node
.
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
list_
inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
inner_sitsot_outs
(
self
):
def
inner_sitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
offset
=
self
.
n_mit_sot
+
n_taps
return
self
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
list_
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot_outs
(
self
,
node
):
def
outer_sitsot_outs
(
self
,
list_outputs
):
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
return
node
.
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
return
list_
outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_nitsot
(
self
,
node
):
def
outer_nitsot
(
self
,
list_inputs
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
self
.
n_sit_sot
+
self
.
n_shared_outs
)
return
node
.
inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
list_
inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_nitsot_outs
(
self
):
def
inner_nitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
return
self
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
list_
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
outer_nitsot_outs
(
self
,
node
):
def
outer_nitsot_outs
(
self
,
list_outputs
):
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
return
node
.
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
return
list_
outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_shared
(
self
):
def
inner_shared
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
return
self
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
list_
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared
(
self
,
node
):
def
outer_shared
(
self
,
list_inputs
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
return
node
.
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
list_
inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_shared_outs
(
self
):
def
inner_shared_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
+
self
.
n_nit_sot
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
+
self
.
n_nit_sot
return
self
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
list_
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared_outs
(
self
,
node
):
def
outer_shared_outs
(
self
,
list_outputs
):
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
return
node
.
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
return
list_
outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_non_seqs
(
self
):
def
inner_non_seqs
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
tap_array
[:(
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_sot
)])
offset
=
(
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
+
offset
=
(
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
return
self
.
inputs
[
offset
:]
return
list_
inputs
[
offset
:]
def
outer_non_seqs
(
self
,
node
):
def
outer_non_seqs
(
self
,
list_inputs
):
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
return
node
.
inputs
[
offset
:]
return
list_
inputs
[
offset
:]
def
execute
(
self
,
node
,
args
,
outs
):
def
execute
(
self
,
node
,
args
,
outs
):
"""
"""
...
...
theano/scan_module/scan_opt.py
浏览文件 @
9d9d9020
...
@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer):
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
ViewOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)
and
not
isinstance
(
nd
.
op
,
theano
.
compile
.
DeepCopyOp
)
and
# and we didn't already looked at this node
# and we didn't already looked at this node
not
nd
in
to_remove
not
nd
in
to_remove
):
):
# We have a candidate node to removable
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
# Step 1. Reconstruct it on outside
...
@@ -317,12 +316,12 @@ def scan_make_inplace(node):
...
@@ -317,12 +316,12 @@ def scan_make_inplace(node):
info
[
'inplace'
]
=
True
info
[
'inplace'
]
=
True
# inputs corresponding to sequences and n_steps
# inputs corresponding to sequences and n_steps
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls_begin
=
node
.
inputs
[:
1
+
op
.
n_seqs
]
ls
=
op
.
outer_mitmot
(
node
)
ls
=
op
.
outer_mitmot
(
node
.
inputs
)
ls
+=
op
.
outer_mitsot
(
node
)
ls
+=
op
.
outer_mitsot
(
node
.
inputs
)
ls
+=
op
.
outer_sitsot
(
node
)
ls
+=
op
.
outer_sitsot
(
node
.
inputs
)
ls_end
=
op
.
outer_shared
(
node
)
ls_end
=
op
.
outer_shared
(
node
.
inputs
)
ls_end
+=
op
.
outer_nitsot
(
node
)
ls_end
+=
op
.
outer_nitsot
(
node
.
inputs
)
ls_end
+=
op
.
outer_non_seqs
(
node
)
ls_end
+=
op
.
outer_non_seqs
(
node
.
inputs
)
n_outs
=
len
(
ls
)
n_outs
=
len
(
ls
)
for
idx
in
xrange
(
n_outs
):
for
idx
in
xrange
(
n_outs
):
if
ls
[
idx
]
in
ls
[:
idx
]:
if
ls
[
idx
]
in
ls
[:
idx
]:
...
@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer):
fslice
=
slice
(
fslice
=
slice
(
sanitize
(
cnf_slice
[
0
]
.
start
),
sanitize
(
cnf_slice
[
0
]
.
start
),
sanitize
(
cnf_slice
[
0
]
.
stop
),
sanitize
(
cnf_slice
[
0
]
.
stop
),
sanitize
(
cnf_slice
[
0
]
.
step
)
sanitize
(
cnf_slice
[
0
]
.
step
))
)
else
:
else
:
fslice
=
sanitize
(
cnf_slice
[
0
])
fslice
=
sanitize
(
cnf_slice
[
0
])
...
@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer):
...
@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer):
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Seq
# Seq
inner_ins
+=
rename
(
nd
.
op
.
inner_seqs
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_seqs
(
nd
.
op
.
inputs
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_seqs
(
nd
.
inputs
),
idx
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# MitMot
# MitMot
inner_ins
+=
rename
(
nd
.
op
.
inner_mitmot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_mitmot
(
nd
.
op
.
inputs
),
idx
)
inner_outs
+=
nd
.
op
.
inner_mitmot_outs
()
inner_outs
+=
nd
.
op
.
inner_mitmot_outs
(
nd
.
op
.
outputs
)
info
[
'tap_array'
]
+=
nd
.
op
.
mitmot_taps
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitmot_taps
()
info
[
'mit_mot_out_slices'
]
+=
nd
.
op
.
mitmot_out_taps
()
info
[
'mit_mot_out_slices'
]
+=
nd
.
op
.
mitmot_out_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitmot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_mitmot_outs
(
nd
.
outputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# MitSot
# MitSot
inner_ins
+=
rename
(
nd
.
op
.
inner_mitsot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_mitsot
(
nd
.
op
.
inputs
),
idx
)
inner_outs
+=
nd
.
op
.
inner_mitsot_outs
()
inner_outs
+=
nd
.
op
.
inner_mitsot_outs
(
nd
.
op
.
outputs
)
info
[
'tap_array'
]
+=
nd
.
op
.
mitsot_taps
()
info
[
'tap_array'
]
+=
nd
.
op
.
mitsot_taps
()
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_mitsot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_mitsot_outs
(
nd
.
outputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# SitSot
# SitSot
inner_ins
+=
rename
(
nd
.
op
.
inner_sitsot
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_sitsot
(
nd
.
op
.
inputs
),
idx
)
info
[
'tap_array'
]
+=
[[
-
1
]
for
x
in
xrange
(
nd
.
op
.
n_sit_sot
)]
info
[
'tap_array'
]
+=
[[
-
1
]
for
x
in
xrange
(
nd
.
op
.
n_sit_sot
)]
inner_outs
+=
nd
.
op
.
inner_sitsot_outs
()
inner_outs
+=
nd
.
op
.
inner_sitsot_outs
(
nd
.
op
.
outputs
)
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_sitsot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_sitsot_outs
(
nd
.
outputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
# Shared
inner_ins
+=
rename
(
nd
.
op
.
inner_shared
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_shared
(
nd
.
op
.
inputs
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_shared
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_shared
(
nd
.
inputs
),
idx
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# NitSot
# NitSot
inner_outs
+=
nd
.
op
.
inner_nitsot_outs
()
inner_outs
+=
nd
.
op
.
inner_nitsot_outs
(
nd
.
op
.
outputs
)
outer_ins
+=
rename
(
nd
.
op
.
outer_nitsot
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_nitsot
(
nd
.
inputs
),
idx
)
outer_outs
+=
nd
.
op
.
outer_nitsot_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_nitsot_outs
(
nd
.
outputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
# Shared
outer_outs
+=
nd
.
op
.
outer_shared_outs
(
nd
)
outer_outs
+=
nd
.
op
.
outer_shared_outs
(
nd
.
outputs
)
inner_outs
+=
nd
.
op
.
inner_shared_outs
()
inner_outs
+=
nd
.
op
.
inner_shared_outs
(
nd
.
op
.
outputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
for
idx
,
nd
in
enumerate
(
nodes
):
# Non Seqs
# Non Seqs
inner_ins
+=
rename
(
nd
.
op
.
inner_non_seqs
(),
idx
)
inner_ins
+=
rename
(
nd
.
op
.
inner_non_seqs
(
nd
.
op
.
inputs
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_non_seqs
(
nd
),
idx
)
outer_ins
+=
rename
(
nd
.
op
.
outer_non_seqs
(
nd
.
inputs
),
idx
)
# Add back the number of steps
# Add back the number of steps
outer_ins
=
[
nodes
[
0
]
.
inputs
[
0
]]
+
outer_ins
outer_ins
=
[
nodes
[
0
]
.
inputs
[
0
]]
+
outer_ins
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论