Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d07219c7
提交
d07219c7
authored
1月 22, 2012
作者:
Razvan Pascanu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PEP8 fixes
Not sure it makes the file any more readable, but at least I've tried.
上级
b873707b
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
306 行增加
和
294 行删除
+306
-294
scan_op.py
theano/scan_module/scan_op.py
+306
-294
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
d07219c7
...
@@ -188,18 +188,21 @@ class Scan(PureOp):
...
@@ -188,18 +188,21 @@ class Scan(PureOp):
'following error has been encountered: The '
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)'
'initial state (outputs_info in scan nomenclature)'
'of variable
%
s (argument number
%
d)'
'of variable
%
s (argument number
%
d)'
' has dtype
%
s and
%
d dimension(s), while the result
of the
'
' has dtype
%
s and
%
d dimension(s), while the result '
'
inner function for this output has dtype
%
s and
%
d
'
'
of the inner function for this output has dtype
%
s
'
'
dimension(s). This could happen if the inner graph of
'
'
and
%
d dimension(s). This could happen if the inner
'
'
scan results in an upcast or downcast. Please make
'
'
graph of scan results in an upcast or downcast.
'
'sure that you use dtypes consistently'
)
'
Please make
sure that you use dtypes consistently'
)
# TODO make the assert exact
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and inputs correspond)
# TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs)
#assert len(inputs) >= len(self.inputs)
# if self.info['as_while']:
#if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + self.info["n_nit_sot"]
# assert len(inputs) == len(self.inputs) + 2 + \
# else:
# self.info["n_nit_sot"]
# assert len(inputs) == len(self.inputs) + 1 + self.info["n_nit_sot"]
#else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
# Flags that indicate which inputs are vectors
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
...
@@ -236,26 +239,27 @@ class Scan(PureOp):
...
@@ -236,26 +239,27 @@ class Scan(PureOp):
self
.
mitmot_out_taps
(),
self
.
mitmot_out_taps
(),
self
.
outer_mitmot
(
inputs
))):
self
.
outer_mitmot
(
inputs
))):
for
k
in
xrange
(
len
(
itaps
)):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
outer_mitmot
.
type
.
dtype
or
outer_mitmot
.
type
.
dtype
or
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_mitmot
),
str
(
outer_mitmot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
outer_mitmot
.
type
.
dtype
,
str
(
inner_mitmot
[
ipos
+
k
]),
str
(
inner_mitmot
[
ipos
+
k
]),
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
ipos
+=
len
(
itaps
)
ipos
+=
len
(
itaps
)
for
k
in
xrange
(
len
(
otaps
)):
for
k
in
xrange
(
len
(
otaps
)):
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
\
outer_mitmot
.
type
.
dtype
or
outer_mitmot
.
type
.
dtype
or
inner_mitmot_outs
[
opos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inner_mitmot_outs
[
opos
+
k
]
.
ndim
!=
\
raise
ValueError
(
err_msg2
%
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitmot
),
(
str
(
outer_mitmot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
outer_mitmot
.
type
.
dtype
,
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
))
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
))
opos
+=
len
(
otaps
)
opos
+=
len
(
otaps
)
argoffset
+=
len
(
self
.
outer_mitmot
(
inputs
))
argoffset
+=
len
(
self
.
outer_mitmot
(
inputs
))
# Same checks as above but for outputs of type mit_sot
# Same checks as above but for outputs of type mit_sot
...
@@ -266,28 +270,28 @@ class Scan(PureOp):
...
@@ -266,28 +270,28 @@ class Scan(PureOp):
self
.
outer_mitsot
(
inputs
),
self
.
outer_mitsot
(
inputs
),
self
.
inner_mitsot_outs
(
self
.
outputs
))):
self
.
inner_mitsot_outs
(
self
.
outputs
))):
for
k
in
xrange
(
len
(
itaps
)):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
\
outer_mitsot
.
type
.
dtype
or
outer_mitsot
.
type
.
dtype
or
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_mitsot
),
str
(
outer_mitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
dtype
,
otuer_mitsot
.
type
.
ndim
,
otuer_mitsot
.
type
.
ndim
,
str
(
inner_mitsot
[
ipos
+
k
]),
str
(
inner_mitsot
[
ipos
+
k
]),
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
ipos
+=
len
(
itaps
)
ipos
+=
len
(
itaps
)
if
(
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
or
if
(
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
or
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitsot
),
(
str
(
outer_mitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
ndim
,
outer_mitsot
.
type
.
ndim
,
inner_mitsot_out
.
type
.
dtype
,
inner_mitsot_out
.
type
.
dtype
,
inner_mitsot_out
.
type
.
ndim
))
inner_mitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
# Same checks as above but for outputs of type sit_sot
# Same checks as above but for outputs of type sit_sot
...
@@ -308,13 +312,13 @@ class Scan(PureOp):
...
@@ -308,13 +312,13 @@ class Scan(PureOp):
inner_sitsot
.
type
.
ndim
))
inner_sitsot
.
type
.
ndim
))
if
(
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
if
(
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
inner_sitsot_out
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
inner_sitsot_out
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
raise
ValueError
(
err_msg2
%
(
str
(
outer_sitsot
),
(
str
(
outer_sitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
outer_sitsot
.
type
.
ndim
,
inner_sitsot_out
.
type
.
dtype
,
inner_sitsot_out
.
type
.
dtype
,
inner_sitsot_out
.
type
.
ndim
))
inner_sitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
# Check that the shared variable and their update rule have the same
# Check that the shared variable and their update rule have the same
...
@@ -352,7 +356,7 @@ class Scan(PureOp):
...
@@ -352,7 +356,7 @@ class Scan(PureOp):
inner_nonseq
.
type
.
ndim
!=
outer_nonseq
.
type
.
ndim
):
inner_nonseq
.
type
.
ndim
!=
outer_nonseq
.
type
.
ndim
):
raise
ValueError
((
'Argument
%
s given to scan node does not'
raise
ValueError
((
'Argument
%
s given to scan node does not'
' match its correspondance
%
s'
)
%
' match its correspondance
%
s'
)
%
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
# For every nit_sot input we get as input a int/uint that
# For every nit_sot input we get as input a int/uint that
...
@@ -1120,7 +1124,7 @@ class Scan(PureOp):
...
@@ -1120,7 +1124,7 @@ class Scan(PureOp):
# if we are dealing with a repeat-until, then we do not know the
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
# leading dimension so we replace it for every entry with Shape_i
if
self
.
as_while
:
if
self
.
as_while
:
scan_outs
=
[(
Shape_i
(
0
)(
o
),)
+
x
[
1
:]
scan_outs
=
[(
Shape_i
(
0
)(
o
),)
+
x
[
1
:]
for
o
,
x
in
izip
(
node
.
outputs
,
scan_outs
)]
for
o
,
x
in
izip
(
node
.
outputs
,
scan_outs
)]
return
scan_outs
return
scan_outs
...
@@ -1147,79 +1151,80 @@ class Scan(PureOp):
...
@@ -1147,79 +1151,80 @@ class Scan(PureOp):
in
xrange
(
self
.
n_mit_mot
)])
in
xrange
(
self
.
n_mit_mot
)])
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
+=
n_ins_mit_mot
offset
+=
n_ins_mit_mot
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
in
xrange
(
self
.
n_mit_mot
,
,
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
offset
+=
n_ins_mit_sot
offset
+=
n_ins_mit_sot
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
offset
+=
self
.
n_sit_sot
offset
+=
self
.
n_sit_sot
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
out_offset
=
(
self
.
n_mit_mot_outs
out_offset
=
(
self
.
n_mit_mot_outs
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_nit_sot
self
.
n_nit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
# shared variables as well as the condition
# shared variables as well as the condition
old_scan_shared_outs
=
self_outputs
[
out_offset
:]
old_scan_shared_outs
=
self_outputs
[
out_offset
:]
arg_offset
=
(
1
arg_offset
=
(
1
+
+
self
.
n_seqs
self
.
n_seqs
+
+
self
.
n_mit_mot
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
offset
+=
self
.
n_shared_outs
offset
+=
self
.
n_shared_outs
other_args
=
self_inputs
[
offset
:]
other_args
=
self_inputs
[
offset
:]
# 4. Collect (possibly) differentiable inputs
# 4. Collect (possibly) differentiable inputs
diff_inputs
=
(
seqs
+
diff_inputs
=
(
seqs
+
outs_mit_mot
+
outs_mit_mot
+
outs_mit_sot
+
outs_mit_sot
+
outs_sit_sot
+
outs_sit_sot
+
other_args
)
other_args
)
#args[-len(other_args):] )
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
# the gradients with respect to all outputs)
def
compute_gradient
(
y
,
g_y
):
def
compute_gradient
(
y
,
g_y
):
gmp
=
gradient
.
grad_sources_inputs
(
gmp
=
gradient
.
grad_sources_inputs
(
[(
y
,
g_y
)],
diff_inputs
,
False
)
[(
y
,
g_y
)],
diff_inputs
,
False
)
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
# 6. clean the outputs (i.e. remove update rules)
# 6. clean the outputs (i.e. remove update rules)
end
=
(
self
.
n_mit_mot_outs
end
=
(
self
.
n_mit_mot_outs
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
self
.
n_sit_sot
+
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
clean_outputs
=
self_outputs
[:
end
]
clean_outputs
=
self_outputs
[:
end
]
g_outs_no_shared
=
g_outs
[:
end
]
g_outs_no_shared
=
g_outs
[:
end
]
# 7.1. empty lists to hold gradients
# 7.1. empty lists to hold gradients
# List of slices from outputs (used to compute the gradients)
# List of slices from outputs (used to compute the gradients)
inner_g_outs
=
[]
inner_g_outs
=
[]
g_out_slices
=
[]
g_out_slices
=
[]
# List of outputs of the gradient function
# List of outputs of the gradient function
inner_gfn_outs
=
[]
inner_gfn_outs
=
[]
# slices of the input
# slices of the input
prev_inner_gfn_outs
=
[]
prev_inner_gfn_outs
=
[]
zeros_like_diff_ins
=
[]
zeros_like_diff_ins
=
[]
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
offset
=
len
(
args
)
-
len
(
other_args
)
-
pos
offset
=
len
(
args
)
-
len
(
other_args
)
-
pos
# 7.2. generate variables to represent previous steps of g_outs
# 7.2. generate variables to represent previous steps of g_outs
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
prev_gfn_out
=
safe_new
(
diff_in
)
prev_gfn_out
=
safe_new
(
diff_in
)
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
else
:
else
:
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
if
idx
<
pos
:
if
idx
<
pos
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
diff_in
))
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
diff_in
))
else
:
else
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
# 7.3. compute gradients of the inputs given one output
# 7.3. compute gradients of the inputs given one output
for
dx
,
out
in
enumerate
(
clean_outputs
):
for
dx
,
out
in
enumerate
(
clean_outputs
):
...
@@ -1227,32 +1232,30 @@ class Scan(PureOp):
...
@@ -1227,32 +1232,30 @@ class Scan(PureOp):
###
###
#### I need to clip the gradient HERE !!
#### I need to clip the gradient HERE !!
if
g_outs_no_shared
[
dx
]:
if
g_outs_no_shared
[
dx
]:
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
else
:
else
:
g_out_slices
.
append
(
None
)
g_out_slices
.
append
(
None
)
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
inner_g_out
.
name
=
'g_'
+
out
.
name
inner_g_out
.
name
=
'g_'
+
out
.
name
else
:
else
:
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_outs
.
append
(
inner_g_out
)
inner_g_outs
.
append
(
inner_g_out
)
_g_out
=
inner_g_out
_g_out
=
inner_g_out
grad_outs
=
compute_gradient
(
out
,
_g_out
)
grad_outs
=
compute_gradient
(
out
,
_g_out
)
if
not
inner_gfn_outs
:
if
not
inner_gfn_outs
:
for
idx
,
gfn_out
in
enumerate
(
grad_outs
):
for
idx
,
gfn_out
in
enumerate
(
grad_outs
):
if
idx
>=
self
.
n_seqs
:
if
idx
>=
self
.
n_seqs
:
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
]
)
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
]
)
else
:
else
:
inner_gfn_outs
.
append
(
None
)
inner_gfn_outs
.
append
(
None
)
# 7.4 Sum the gradients
# 7.4 Sum the gradients
# safety check, some of these inputs might still not be
# safety check, some of these inputs might still not be
# differentiable, for those we don't add them to the mix
# differentiable, for those we don't add them to the mix
# (assume their gradient is 0)
# (assume their gradient is 0)
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
if
x
and
y
:
if
x
and
y
:
inner_gfn_outs
[
i
]
=
x
+
y
inner_gfn_outs
[
i
]
=
x
+
y
elif
y
:
elif
y
:
inner_gfn_outs
[
i
]
=
y
inner_gfn_outs
[
i
]
=
y
else
:
else
:
...
@@ -1276,28 +1279,27 @@ class Scan(PureOp):
...
@@ -1276,28 +1279,27 @@ class Scan(PureOp):
g_outs
[
i
]
=
theano
.
tensor
.
constant
(
g_outs
[
i
]
=
theano
.
tensor
.
constant
(
numpy
.
array
(
0
,
theano
.
config
.
floatX
))
numpy
.
array
(
0
,
theano
.
config
.
floatX
))
## 10. Get your sequence in order for the scan:
## 10. Get your sequence in order for the scan:
n_seqs
=
(
self
.
n_seqs
+
n_seqs
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_mot
+
n_ins_mit_sot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
offset
=
(
self
.
n_mit_mot_outs
+
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
inner_seqs
=
(
seqs
+
inner_seqs
=
(
seqs
+
outs_mit_mot
+
outs_mit_mot
+
outs_mit_sot
+
outs_mit_sot
+
outs_sit_sot
+
outs_sit_sot
+
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
offset
=
0
offset
=
0
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
):
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
):
mintap
=
numpy
.
min
(
self
.
tap_array
[
idx
])
mintap
=
numpy
.
min
(
self
.
tap_array
[
idx
])
maxtap
=
numpy
.
max
(
self
.
tap_array
[
idx
])
maxtap
=
numpy
.
max
(
self
.
tap_array
[
idx
])
seq
=
scan_outputs
[
offset
+
idx
]
seq
=
scan_outputs
[
offset
+
idx
]
for
k
in
self
.
tap_array
[
idx
]:
for
k
in
self
.
tap_array
[
idx
]:
# We cut the sequence such that seq[i] to correspond to
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
# seq[i-k]
...
@@ -1307,205 +1309,205 @@ class Scan(PureOp):
...
@@ -1307,205 +1309,205 @@ class Scan(PureOp):
dim_offset
=
0
dim_offset
=
0
if
maxtap
==
mintap
and
maxtap
!=
0
:
if
maxtap
==
mintap
and
maxtap
!=
0
:
nw_seq
=
seq
[:
abs
(
maxtap
)]
nw_seq
=
seq
[:
abs
(
maxtap
)]
elif
maxtap
-
k
!=
0
:
elif
maxtap
-
k
!=
0
:
tmp
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
(
maxtap
-
k
+
1
)]
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
\
nw_seq
=
tmp
[::
-
1
]
-
(
maxtap
-
k
+
1
)]
[::
-
1
]
else
:
else
:
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
1
][::
-
1
]
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
1
][::
-
1
]
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
scan_seqs
.
append
(
nw_seq
)
scan_seqs
.
append
(
nw_seq
)
offset
+=
self
.
n_mit_sot
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
for
idx
in
xrange
(
self
.
n_sit_sot
):
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
scan_seqs
.
append
(
seq
[::
-
1
])
scan_seqs
.
append
(
seq
[::
-
1
])
offset
=
(
self
.
n_mit_mot_outs
+
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
scan_mit_mot
=
[]
scan_mit_mot
=
[]
inner_mit_mot
=
[]
inner_mit_mot
=
[]
scan_mit_mot_outs
=
[]
scan_mit_mot_outs
=
[]
mit_mot_taps
=
[]
mit_mot_taps
=
[]
mit_mot_out_slices
=
[]
mit_mot_out_slices
=
[]
out_pos
=
0
out_pos
=
0
ins_pos
=
n_seqs
ins_pos
=
n_seqs
n_mit_mot_outs
=
0
n_mit_mot_outs
=
0
n_mit_mot_ins
=
0
n_mit_mot_ins
=
0
ins_pos
=
self
.
n_seqs
ins_pos
=
self
.
n_seqs
for
idx
in
xrange
(
self
.
n_mit_mot
):
for
idx
in
xrange
(
self
.
n_mit_mot
):
scan_mit_mot
.
append
(
g_outs
[
idx
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
][::
-
1
]
)
mit_mot_taps
.
append
([])
mit_mot_taps
.
append
([])
mit_mot_out_slices
.
append
([])
mit_mot_out_slices
.
append
([])
for
jdx
in
xrange
(
len
(
self
.
mit_mot_out_slices
[
idx
])):
for
jdx
in
xrange
(
len
(
self
.
mit_mot_out_slices
[
idx
])):
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
mit_mot_taps
[
idx
]
.
append
(
mit_mot_taps
[
idx
]
.
append
(
\
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
out_pos
+=
1
out_pos
+=
1
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx
])):
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
scan_mit_mot_outs
.
append
(
scan_mit_mot_outs
.
append
(
\
inner_gfn_outs
[
ins_pos
]
)
inner_gfn_outs
[
ins_pos
]
)
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
ins_pos
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
n_mit_mot_outs
+=
1
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx
][
jdx
]
)
-
self
.
tap_array
[
idx
][
jdx
])
offset
=
self
.
n_mit_mot
offset
=
self
.
n_mit_mot
for
idx
in
xrange
(
self
.
n_mit_sot
):
for
idx
in
xrange
(
self
.
n_mit_sot
):
mit_mot_taps
.
append
([])
mit_mot_taps
.
append
([])
mit_mot_out_slices
.
append
([])
mit_mot_out_slices
.
append
([])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
idx_tap
=
idx
+
self
.
n_mit_mot
idx_tap
=
idx
+
self
.
n_mit_mot
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx_tap
])):
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx_tap
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
mit_mot_taps
[
idx
+
offset
]
.
append
(
mit_mot_taps
[
idx
+
offset
]
.
append
(
\
-
self
.
tap_array
[
idx_tap
][
jdx
]
)
-
self
.
tap_array
[
idx_tap
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx_tap
][
jdx
]
)
-
self
.
tap_array
[
idx_tap
][
jdx
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
]
)
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
ins_pos
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
n_mit_mot_outs
+=
1
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
out_pos
+=
1
out_pos
+=
1
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
offset
+=
self
.
n_mit_sot
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
for
idx
in
xrange
(
self
.
n_sit_sot
):
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_out_slices
.
append
([
1
])
mit_mot_out_slices
.
append
([
1
])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
]
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
],
,
prev_inner_gfn_outs
[
ins_pos
]
]
prev_inner_gfn_outs
[
ins_pos
]
]
n_mit_mot_outs
+=
1
n_mit_mot_outs
+=
1
out_pos
+=
1
out_pos
+=
1
ins_pos
+=
1
ins_pos
+=
1
n_mit_mot_ins
+=
2
n_mit_mot_ins
+=
2
n_nit_sot
=
self
.
n_seqs
n_nit_sot
=
self
.
n_seqs
scan_nit_sot_outs
=
inner_gfn_outs
[:
self
.
n_seqs
]
scan_nit_sot_outs
=
inner_gfn_outs
[:
self
.
n_seqs
]
if
self
.
truncate_gradient
!=
-
1
:
if
self
.
truncate_gradient
!=
-
1
:
do_steps
=
tensor
.
minimum
(
args
[
0
],
self
.
truncate_gradient
)
do_steps
=
tensor
.
minimum
(
args
[
0
],
self
.
truncate_gradient
)
else
:
else
:
do_steps
=
args
[
0
]
do_steps
=
args
[
0
]
offset
=
(
self
.
n_seqs
offset
=
(
self
.
n_seqs
+
+
n_ins_mit_sot
n_ins_mit_sot
+
+
n_ins_mit_mot
n_ins_mit_mot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
# Instead of shared outs use sit_sot
# Instead of shared outs use sit_sot
n_sitsot_outs
=
len
(
prev_inner_gfn_outs
[
offset
:])
n_sitsot_outs
=
len
(
prev_inner_gfn_outs
[
offset
:])
scan_sitsot_ins
=
prev_inner_gfn_outs
[
offset
:]
scan_sitsot_ins
=
prev_inner_gfn_outs
[
offset
:]
scan_sitsot_init
=
[]
scan_sitsot_init
=
[]
for
x
in
zeros_like_diff_ins
[
offset
:]:
for
x
in
zeros_like_diff_ins
[
offset
:]:
shapes
=
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)]
shapes
=
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)]
empty
=
tensor
.
zeros
([
do_steps
+
1
]
+
shapes
,
empty
=
tensor
.
zeros
([
do_steps
+
1
]
+
shapes
,
dtype
=
x
.
dtype
)
dtype
=
x
.
dtype
)
scan_sitsot_init
.
append
(
empty
)
scan_sitsot_init
.
append
(
empty
)
scan_sitsot_outs
=
inner_gfn_outs
[
offset
:]
scan_sitsot_outs
=
inner_gfn_outs
[
offset
:]
tap_array
=
mit_mot_taps
+
[[
-
1
]
for
k
in
tap_array
=
mit_mot_taps
+
[[
-
1
]
for
k
in
xrange
(
n_sitsot_outs
)]
xrange
(
n_sitsot_outs
)]
info
=
{}
info
=
{}
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_mit_sot'
]
=
0
info
[
'n_mit_sot'
]
=
0
info
[
'tap_array'
]
=
tap_array
info
[
'tap_array'
]
=
tap_array
info
[
'gpu'
]
=
False
info
[
'gpu'
]
=
False
n_mit_mot
=
(
self
.
n_mit_mot
n_mit_mot
=
(
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'n_sit_sot'
]
=
n_sitsot_outs
info
[
'n_sit_sot'
]
=
n_sitsot_outs
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'as_while'
]
=
self
.
as_while
info
[
'as_while'
]
=
self
.
as_while
info
[
'profile'
]
=
self
.
profile
info
[
'profile'
]
=
self
.
profile
if
self
.
name
:
if
self
.
name
:
info
[
'name'
]
=
'grad_of_'
+
self
.
name
info
[
'name'
]
=
'grad_of_'
+
self
.
name
else
:
else
:
info
[
'name'
]
=
None
info
[
'name'
]
=
None
info
[
'mode'
]
=
self
.
mode
info
[
'mode'
]
=
self
.
mode
info
[
'inplace'
]
=
False
info
[
'inplace'
]
=
False
n_mit_sot
=
0
n_mit_sot
=
0
n_sit_sot
=
0
n_sit_sot
=
0
offset
=
(
1
offset
=
(
1
+
+
self
.
n_seqs
self
.
n_seqs
+
+
self
.
n_mit_mot
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
self
.
n_sit_sot
+
+
self
.
n_nit_sot
self
.
n_nit_sot
+
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
scan_inputs
=
(
[
do_steps
]
+
scan_inputs
=
(
[
do_steps
]
+
scan_seqs
+
scan_seqs
+
scan_mit_mot
+
scan_mit_mot
+
scan_sitsot_init
+
scan_sitsot_init
+
old_scan_init
+
old_scan_init
+
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)
]
+
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)
]
+
args
[
offset
:]
)
args
[
offset
:])
offset
=
(
self
.
n_seqs
offset
=
(
self
.
n_seqs
+
+
n_ins_mit_mot
n_ins_mit_mot
+
+
n_ins_mit_sot
n_ins_mit_sot
+
+
self
.
n_sit_sot
self
.
n_sit_sot
+
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
inner_other_args
=
self_inputs
[
offset
:]
inner_other_args
=
self_inputs
[
offset
:]
inner_gfn_ins
=
(
inner_seqs
+
inner_gfn_ins
=
(
inner_seqs
+
inner_mit_mot
+
inner_mit_mot
+
scan_sitsot_ins
+
scan_sitsot_ins
+
old_scan_shared_ins
+
old_scan_shared_ins
+
inner_other_args
)
inner_other_args
)
inner_gfn_outs
=
(
scan_mit_mot_outs
+
inner_gfn_outs
=
(
scan_mit_mot_outs
+
scan_sitsot_outs
+
scan_sitsot_outs
+
scan_nit_sot_outs
+
scan_nit_sot_outs
+
old_scan_shared_outs
)
old_scan_shared_outs
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
outputs
=
local_op
(
*
scan_inputs
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
# Re-order the gradients correctly
# Re-order the gradients correctly
gradients
=
[
None
]
gradients
=
[
None
]
offset
=
(
self
.
n_mit_mot
offset
=
(
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
self
.
n_sit_sot
+
+
n_sitsot_outs
)
n_sitsot_outs
)
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[
offset
:
offset
+
self
.
n_seqs
]]
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[
offset
:
offset
+
self
.
n_seqs
]]
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_shared_outs
)]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_shared_outs
)]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_nit_sot
)
]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_nit_sot
)
]
begin
=
end
begin
=
end
end
=
begin
+
n_sitsot_outs
end
=
begin
+
n_sitsot_outs
gradients
+=
[
x
[
-
1
]
for
x
in
outputs
[
begin
:
end
]]
gradients
+=
[
x
[
-
1
]
for
x
in
outputs
[
begin
:
end
]]
return
gradients
return
gradients
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
# Step 0. Don't work on the original tensor variables
# Step 0. Don't work on the original tensor variables
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
,
'_rop'
)
self
.
outputs
,
'_rop'
)
self_inputs
=
rval
[
0
]
self_inputs
=
rval
[
0
]
self_outputs
=
rval
[
1
]
self_outputs
=
rval
[
1
]
# Step 1. Compute the R_op of the inner function
# Step 1. Compute the R_op of the inner function
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
for
x
in
self_inputs
]
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
for
x
in
self_inputs
]
if
self
.
as_while
:
if
self
.
as_while
:
rop_self_outputs
=
self_outputs
[:
-
1
]
rop_self_outputs
=
self_outputs
[:
-
1
]
else
:
else
:
...
@@ -1524,82 +1526,82 @@ class Scan(PureOp):
...
@@ -1524,82 +1526,82 @@ class Scan(PureOp):
# eval point for the number of nit_sot which I think should just be
# eval point for the number of nit_sot which I think should just be
# ignored (?)
# ignored (?)
info
=
{}
info
=
{}
info
[
'n_seqs'
]
=
self
.
n_seqs
*
2
info
[
'n_seqs'
]
=
self
.
n_seqs
*
2
info
[
'n_mit_sot'
]
=
self
.
n_mit_sot
*
2
info
[
'n_mit_sot'
]
=
self
.
n_mit_sot
*
2
info
[
'n_sit_sot'
]
=
self
.
n_sit_sot
*
2
info
[
'n_sit_sot'
]
=
self
.
n_sit_sot
*
2
info
[
'n_mit_mot'
]
=
self
.
n_mit_mot
*
2
info
[
'n_mit_mot'
]
=
self
.
n_mit_mot
*
2
info
[
'n_nit_sot'
]
=
self
.
n_nit_sot
*
2
info
[
'n_nit_sot'
]
=
self
.
n_nit_sot
*
2
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
*
2
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
*
2
info
[
'gpu'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
self
.
as_while
info
[
'as_while'
]
=
self
.
as_while
info
[
'profile'
]
=
self
.
profile
info
[
'profile'
]
=
self
.
profile
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
if
self
.
name
:
if
self
.
name
:
info
[
'name'
]
=
'rop_of_'
+
self
.
name
info
[
'name'
]
=
'rop_of_'
+
self
.
name
else
:
else
:
info
[
'name'
]
=
None
info
[
'name'
]
=
None
info
[
'mode'
]
=
self
.
mode
info
[
'mode'
]
=
self
.
mode
info
[
'inplace'
]
=
False
info
[
'inplace'
]
=
False
info
[
'mit_mot_out_slices'
]
=
self
.
mit_mot_out_slices
*
2
info
[
'mit_mot_out_slices'
]
=
self
.
mit_mot_out_slices
*
2
new_tap_array
=
[]
new_tap_array
=
[]
b
=
0
b
=
0
e
=
self
.
n_mit_mot
e
=
self
.
n_mit_mot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
b
=
e
b
=
e
e
+=
self
.
n_mit_sot
e
+=
self
.
n_mit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
b
=
e
b
=
e
e
+=
self
.
n_sit_sot
e
+=
self
.
n_sit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
info
[
'tap_array'
]
=
new_tap_array
info
[
'tap_array'
]
=
new_tap_array
# Sequences ...
# Sequences ...
b
=
1
b
=
1
ib
=
0
ib
=
0
e
=
1
+
self
.
n_seqs
e
=
1
+
self
.
n_seqs
ie
=
self
.
n_seqs
ie
=
self
.
n_seqs
scan_seqs
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_seqs
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_seqs
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_seqs
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
# MIT_MOT sequences ...
# MIT_MOT sequences ...
b
=
e
b
=
e
e
=
e
+
self
.
n_mit_mot
e
=
e
+
self
.
n_mit_mot
ib
=
ie
ib
=
ie
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
]]))
self
.
tap_array
[:
self
.
n_mit_mot
]]))
scan_mit_mot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_mit_mot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_mit_mot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_mit_mot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
# MIT_SOT sequences ...
# MIT_SOT sequences ...
b
=
e
b
=
e
e
=
e
+
self
.
n_mit_sot
e
=
e
+
self
.
n_mit_sot
ib
=
ie
ib
=
ie
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
self
.
tap_array
[
self
.
n_mit_mot
:
\
scan_mit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
scan_mit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_mit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_mit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
#SIT_SOT sequences ...
#SIT_SOT sequences ...
b
=
e
b
=
e
e
=
e
+
self
.
n_sit_sot
e
=
e
+
self
.
n_sit_sot
ib
=
ie
ib
=
ie
ie
=
ie
+
self
.
n_sit_sot
ie
=
ie
+
self
.
n_sit_sot
scan_sit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_sit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_sit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_sit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
#Shared outs ...
#Shared outs ...
b
=
e
b
=
e
e
=
e
+
self
.
n_shared_outs
e
=
e
+
self
.
n_shared_outs
ib
=
ie
ib
=
ie
ie
=
ie
+
self
.
n_shared_outs
ie
=
ie
+
self
.
n_shared_outs
scan_shared
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_shared
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_shared
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_shared
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
# NIT_SOT sequences
# NIT_SOT sequences
b
=
e
b
=
e
e
=
e
+
self
.
n_nit_sot
e
=
e
+
self
.
n_nit_sot
scan_nit_sot
=
inputs
[
b
:
e
]
*
2
scan_nit_sot
=
inputs
[
b
:
e
]
*
2
# All other arguments
# All other arguments
scan_other
=
inputs
[
e
:]
+
eval_points
[
e
:]
scan_other
=
inputs
[
e
:]
+
eval_points
[
e
:]
...
@@ -1625,13 +1627,13 @@ class Scan(PureOp):
...
@@ -1625,13 +1627,13 @@ class Scan(PureOp):
e
=
e
+
self
.
n_shared_outs
e
=
e
+
self
.
n_shared_outs
inner_out_shared
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
inner_out_shared
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
inner_ins
=
(
inner_seqs
+
inner_ins
=
(
inner_seqs
+
inner_mit_mot
+
inner_mit_mot
+
inner_mit_sot
+
inner_mit_sot
+
inner_sit_sot
+
inner_sit_sot
+
inner_shared
+
inner_shared
+
inner_other
)
inner_other
)
inner_outs
=
(
inner_out_mit_mot
+
inner_outs
=
(
inner_out_mit_mot
+
inner_out_mit_sot
+
inner_out_mit_sot
+
inner_out_sit_sot
+
inner_out_sit_sot
+
inner_out_nit_sot
+
inner_out_nit_sot
+
...
@@ -1639,35 +1641,35 @@ class Scan(PureOp):
...
@@ -1639,35 +1641,35 @@ class Scan(PureOp):
if
self
.
as_while
:
if
self
.
as_while
:
inner_outs
+=
[
self_outputs
[
-
1
]]
inner_outs
+=
[
self_outputs
[
-
1
]]
scan_inputs
=
(
[
inputs
[
0
]]
+
scan_inputs
=
(
[
inputs
[
0
]]
+
scan_seqs
+
scan_seqs
+
scan_mit_mot
+
scan_mit_mot
+
scan_mit_sot
+
scan_mit_sot
+
scan_sit_sot
+
scan_sit_sot
+
scan_shared
+
scan_shared
+
scan_nit_sot
+
scan_nit_sot
+
scan_other
)
scan_other
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
outputs
=
local_op
(
*
scan_inputs
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
# Select only the result of the R_op results
# Select only the result of the R_op results
final_outs
=
[]
final_outs
=
[]
b
=
self
.
n_mit_mot
b
=
self
.
n_mit_mot
e
=
self
.
n_mit_mot
*
2
e
=
self
.
n_mit_mot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_mit_sot
b
=
e
+
self
.
n_mit_sot
e
=
e
+
self
.
n_mit_sot
*
2
e
=
e
+
self
.
n_mit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_sit_sot
b
=
e
+
self
.
n_sit_sot
e
=
e
+
self
.
n_sit_sot
*
2
e
=
e
+
self
.
n_sit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_nit_sot
b
=
e
+
self
.
n_nit_sot
e
=
e
+
self
.
n_nit_sot
*
2
e
=
e
+
self
.
n_nit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_shared_outs
b
=
e
+
self
.
n_shared_outs
e
=
e
+
self
.
n_shared_outs
*
2
e
=
e
+
self
.
n_shared_outs
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
return
final_outs
return
final_outs
...
@@ -1678,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
...
@@ -1678,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time
,
apply_cimpl
,
message
,
outputs_size
,
apply_time
,
apply_cimpl
,
message
,
outputs_size
,
other_time
):
other_time
):
# Scan overhead profile
# Scan overhead profile
if
any
([
isinstance
(
node
.
op
,
Scan
)
and
v
>
0
for
(
_
,
node
),
v
in
if
any
([
isinstance
(
node
.
op
,
Scan
)
and
v
>
0
for
(
_
,
node
),
v
in
apply_time
.
items
()]):
apply_time
.
items
()]):
print
print
print
'Scan overhead:'
print
'Scan overhead:'
print
'<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(
%
scan op time)> <sub scan op time(
%
scan op time)> <node>'
print
(
'<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(
%
scan op time)> <sub scan '
'op time(
%
scan op time)> <node>'
)
total_super_scan_time
=
0
total_super_scan_time
=
0
total_scan_fct_time
=
0
total_scan_fct_time
=
0
total_scan_op_time
=
0
total_scan_op_time
=
0
for
(
_
,
node
),
v
in
apply_time
.
items
():
for
(
_
,
node
),
v
in
apply_time
.
items
():
if
isinstance
(
node
.
op
,
Scan
):
if
isinstance
(
node
.
op
,
Scan
):
if
v
>
0
:
if
v
>
0
:
scan_fct_time
=
node
.
op
.
mode_instance
.
fn_time
scan_fct_time
=
node
.
op
.
mode_instance
.
fn_time
scan_op_time
=
node
.
op
.
mode_instance
.
local_time
scan_op_time
=
node
.
op
.
mode_instance
.
local_time
total_super_scan_time
+=
v
total_super_scan_time
+=
v
total_scan_fct_time
+=
scan_fct_time
total_scan_fct_time
+=
scan_fct_time
total_scan_op_time
+=
scan_op_time
total_scan_op_time
+=
scan_op_time
print
'
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
print
'
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
v
,
scan_fct_time
,
scan_op_time
,
scan_fct_time
/
v
*
100
,
v
,
scan_op_time
/
v
*
100
),
node
scan_fct_time
,
scan_op_time
,
scan_fct_time
/
v
*
100
,
scan_op_time
/
v
*
100
),
node
else
:
else
:
print
' The node took 0s, so we can not compute the overhead'
,
node
print
(
' The node took 0s, so we can not '
print
' total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
'compute the overhead'
),
node
total_super_scan_time
,
total_scan_fct_time
,
total_scan_op_time
,
total_scan_fct_time
/
total_super_scan_time
*
100
,
total_scan_op_time
/
total_super_scan_time
*
100
)
print
' total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
total_super_scan_time
,
total_scan_fct_time
,
total_scan_op_time
,
total_scan_fct_time
/
total_super_scan_time
*
100
,
total_scan_op_time
/
total_super_scan_time
*
100
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论