Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d07219c7
提交
d07219c7
authored
1月 22, 2012
作者:
Razvan Pascanu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PEP8 fixes
Note sure it makes the file anymore readable, but at least I've tried.
上级
b873707b
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
198 行增加
和
186 行删除
+198
-186
scan_op.py
theano/scan_module/scan_op.py
+198
-186
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
d07219c7
...
...
@@ -188,18 +188,21 @@ class Scan(PureOp):
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)'
'of variable
%
s (argument number
%
d)'
' has dtype
%
s and
%
d dimension(s), while the result
of the
'
'
inner function for this output has dtype
%
s and
%
d
'
'
dimension(s). This could happen if the inner graph of
'
'
scan results in an upcast or downcast. Please make
'
'sure that you use dtypes consistently'
)
' has dtype
%
s and
%
d dimension(s), while the result '
'
of the inner function for this output has dtype
%
s
'
'
and
%
d dimension(s). This could happen if the inner
'
'
graph of scan results in an upcast or downcast.
'
'
Please make
sure that you use dtypes consistently'
)
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and inputs correspond)
# TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs)
# if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + self.info["n_nit_sot"]
# else:
# assert len(inputs) == len(self.inputs) + 1 + self.info["n_nit_sot"]
#if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + \
# self.info["n_nit_sot"]
#else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
...
...
@@ -236,26 +239,27 @@ class Scan(PureOp):
self
.
mitmot_out_taps
(),
self
.
outer_mitmot
(
inputs
))):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
outer_mitmot
.
type
.
dtype
or
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
str
(
outer_mitmot
),
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
str
(
inner_mitmot
[
ipos
+
k
]),
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
str
(
inner_mitmot
[
ipos
+
k
]),
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
ipos
+=
len
(
itaps
)
for
k
in
xrange
(
len
(
otaps
)):
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
\
outer_mitmot
.
type
.
dtype
or
inner_mitmot_outs
[
opos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inner_mitmot_outs
[
opos
+
k
]
.
ndim
!=
\
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitmot
),
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
))
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
))
opos
+=
len
(
otaps
)
argoffset
+=
len
(
self
.
outer_mitmot
(
inputs
))
# Same checks as above but for outputs of type mit_sot
...
...
@@ -266,18 +270,18 @@ class Scan(PureOp):
self
.
outer_mitsot
(
inputs
),
self
.
inner_mitsot_outs
(
self
.
outputs
))):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
\
outer_mitsot
.
type
.
dtype
or
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
str
(
outer_mitsot
),
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
otuer_mitsot
.
type
.
ndim
,
str
(
inner_mitsot
[
ipos
+
k
]),
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
str
(
inner_mitsot
[
ipos
+
k
]),
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
ipos
+=
len
(
itaps
)
if
(
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
or
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
...
...
@@ -287,7 +291,7 @@ class Scan(PureOp):
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
ndim
,
inner_mitsot_out
.
type
.
dtype
,
inner_mitsot_out
.
type
.
ndim
))
inner_mitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
# Same checks as above but for outputs of type sit_sot
...
...
@@ -314,7 +318,7 @@ class Scan(PureOp):
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
inner_sitsot_out
.
type
.
dtype
,
inner_sitsot_out
.
type
.
ndim
))
inner_sitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
# Check that the shared variable and their update rule have the same
...
...
@@ -352,7 +356,7 @@ class Scan(PureOp):
inner_nonseq
.
type
.
ndim
!=
outer_nonseq
.
type
.
ndim
):
raise
ValueError
((
'Argument
%
s given to scan node does not'
' match its correspondance
%
s'
)
%
' match its correspondance
%
s'
)
%
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
# For every nit_sot input we get as input a int/uint that
...
...
@@ -1120,7 +1124,7 @@ class Scan(PureOp):
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if
self
.
as_while
:
scan_outs
=
[(
Shape_i
(
0
)(
o
),)
+
x
[
1
:]
scan_outs
=
[(
Shape_i
(
0
)(
o
),)
+
x
[
1
:]
for
o
,
x
in
izip
(
node
.
outputs
,
scan_outs
)]
return
scan_outs
...
...
@@ -1148,50 +1152,49 @@ class Scan(PureOp):
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
+=
n_ins_mit_mot
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
,
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
,
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
offset
+=
n_ins_mit_sot
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
offset
+=
self
.
n_sit_sot
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
out_offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_nit_sot
+
self
.
n_sit_sot
)
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
out_offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_nit_sot
+
self
.
n_sit_sot
)
# shared variables as well as the condition
old_scan_shared_outs
=
self_outputs
[
out_offset
:]
arg_offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
arg_offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
offset
+=
self
.
n_shared_outs
other_args
=
self_inputs
[
offset
:]
# 4. Collect (possibly) differentiable inputs
diff_inputs
=
(
seqs
+
diff_inputs
=
(
seqs
+
outs_mit_mot
+
outs_mit_sot
+
outs_sit_sot
+
other_args
)
other_args
)
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
def
compute_gradient
(
y
,
g_y
):
gmp
=
gradient
.
grad_sources_inputs
(
[(
y
,
g_y
)],
diff_inputs
,
False
)
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
[(
y
,
g_y
)],
diff_inputs
,
False
)
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
# 6. clean the outputs (i.e. remove update rules)
end
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
end
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
clean_outputs
=
self_outputs
[:
end
]
g_outs_no_shared
=
g_outs
[:
end
]
...
...
@@ -1204,22 +1207,24 @@ class Scan(PureOp):
# slices of the input
prev_inner_gfn_outs
=
[]
zeros_like_diff_ins
=
[]
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
)
offset
=
len
(
args
)
-
len
(
other_args
)
-
pos
# 7.2. generate variables to represent previous steps of g_outs
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
prev_gfn_out
=
safe_new
(
diff_in
)
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
else
:
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
if
idx
<
pos
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
diff_in
))
else
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
# 7.3. compute gradients of the inputs given one output
for
dx
,
out
in
enumerate
(
clean_outputs
):
...
...
@@ -1227,32 +1232,30 @@ class Scan(PureOp):
###
#### I need to clip the gradient HERE !!
if
g_outs_no_shared
[
dx
]:
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
else
:
g_out_slices
.
append
(
None
)
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
inner_g_out
.
name
=
'g_'
+
out
.
name
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
inner_g_out
.
name
=
'g_'
+
out
.
name
else
:
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_outs
.
append
(
inner_g_out
)
_g_out
=
inner_g_out
grad_outs
=
compute_gradient
(
out
,
_g_out
)
if
not
inner_gfn_outs
:
for
idx
,
gfn_out
in
enumerate
(
grad_outs
):
if
idx
>=
self
.
n_seqs
:
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
]
)
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
]
)
else
:
inner_gfn_outs
.
append
(
None
)
inner_gfn_outs
.
append
(
None
)
# 7.4 Sum the gradients
# safety check, some of this inputs might still not be
# differentiable, for those we don't add them to the mix
# (assume their gradient is 0)
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
if
x
and
y
:
inner_gfn_outs
[
i
]
=
x
+
y
inner_gfn_outs
[
i
]
=
x
+
y
elif
y
:
inner_gfn_outs
[
i
]
=
y
else
:
...
...
@@ -1276,28 +1279,27 @@ class Scan(PureOp):
g_outs
[
i
]
=
theano
.
tensor
.
constant
(
numpy
.
array
(
0
,
theano
.
config
.
floatX
))
## 10. Get your sequence in order for the scan:
n_seqs
=
(
self
.
n_seqs
+
n_seqs
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_nit_sot
)
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
inner_seqs
=
(
seqs
+
self
.
n_sit_sot
)
inner_seqs
=
(
seqs
+
outs_mit_mot
+
outs_mit_sot
+
outs_sit_sot
+
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
offset
=
0
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
):
mintap
=
numpy
.
min
(
self
.
tap_array
[
idx
])
maxtap
=
numpy
.
max
(
self
.
tap_array
[
idx
])
seq
=
scan_outputs
[
offset
+
idx
]
seq
=
scan_outputs
[
offset
+
idx
]
for
k
in
self
.
tap_array
[
idx
]:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
...
...
@@ -1307,25 +1309,25 @@ class Scan(PureOp):
dim_offset
=
0
if
maxtap
==
mintap
and
maxtap
!=
0
:
nw_seq
=
seq
[:
abs
(
maxtap
)]
elif
maxtap
-
k
!=
0
:
tmp
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
(
maxtap
-
k
+
1
)]
nw_seq
=
tmp
[::
-
1
]
elif
maxtap
-
k
!=
0
:
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
\
-
(
maxtap
-
k
+
1
)]
[::
-
1
]
else
:
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
1
][::
-
1
]
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
scan_seqs
.
append
(
nw_seq
)
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
scan_seqs
.
append
(
seq
[::
-
1
])
offset
=
(
self
.
n_mit_mot_outs
+
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
self
.
n_sit_sot
)
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
scan_mit_mot
=
[]
inner_mit_mot
=
[]
...
...
@@ -1338,80 +1340,79 @@ class Scan(PureOp):
n_mit_mot_ins
=
0
ins_pos
=
self
.
n_seqs
for
idx
in
xrange
(
self
.
n_mit_mot
):
scan_mit_mot
.
append
(
g_outs
[
idx
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
][::
-
1
]
)
mit_mot_taps
.
append
([])
mit_mot_out_slices
.
append
([])
for
jdx
in
xrange
(
len
(
self
.
mit_mot_out_slices
[
idx
])):
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
mit_mot_taps
[
idx
]
.
append
(
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
mit_mot_taps
[
idx
]
.
append
(
\
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
n_mit_mot_ins
+=
1
out_pos
+=
1
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
]
)
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
scan_mit_mot_outs
.
append
(
\
inner_gfn_outs
[
ins_pos
]
)
n_mit_mot_ins
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
]
)
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx
][
jdx
])
offset
=
self
.
n_mit_mot
for
idx
in
xrange
(
self
.
n_mit_sot
):
mit_mot_taps
.
append
([])
mit_mot_out_slices
.
append
([])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
idx_tap
=
idx
+
self
.
n_mit_mot
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx_tap
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
mit_mot_taps
[
idx
+
offset
]
.
append
(
-
self
.
tap_array
[
idx_tap
][
jdx
]
)
mit_mot_out_slices
[
idx
]
.
append
(
-
self
.
tap_array
[
idx_tap
][
jdx
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
]
)
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
mit_mot_taps
[
idx
+
offset
]
.
append
(
\
-
self
.
tap_array
[
idx_tap
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx_tap
][
jdx
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
]
)
n_mit_mot_ins
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
out_pos
+=
1
n_mit_mot_ins
+=
1
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_out_slices
.
append
([
1
])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
]
,
prev_inner_gfn_outs
[
ins_pos
]
]
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
],
prev_inner_gfn_outs
[
ins_pos
]
]
n_mit_mot_outs
+=
1
out_pos
+=
1
ins_pos
+=
1
n_mit_mot_ins
+=
2
n_nit_sot
=
self
.
n_seqs
scan_nit_sot_outs
=
inner_gfn_outs
[:
self
.
n_seqs
]
if
self
.
truncate_gradient
!=
-
1
:
if
self
.
truncate_gradient
!=
-
1
:
do_steps
=
tensor
.
minimum
(
args
[
0
],
self
.
truncate_gradient
)
else
:
do_steps
=
args
[
0
]
offset
=
(
self
.
n_seqs
+
n_ins_mit_sot
+
n_ins_mit_mot
+
self
.
n_sit_sot
)
offset
=
(
self
.
n_seqs
+
n_ins_mit_sot
+
n_ins_mit_mot
+
self
.
n_sit_sot
)
# Instead of shared outs use sit_sot
n_sitsot_outs
=
len
(
prev_inner_gfn_outs
[
offset
:])
scan_sitsot_ins
=
prev_inner_gfn_outs
[
offset
:]
scan_sitsot_init
=
[]
for
x
in
zeros_like_diff_ins
[
offset
:]:
shapes
=
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)]
empty
=
tensor
.
zeros
([
do_steps
+
1
]
+
shapes
,
empty
=
tensor
.
zeros
([
do_steps
+
1
]
+
shapes
,
dtype
=
x
.
dtype
)
scan_sitsot_init
.
append
(
empty
)
scan_sitsot_outs
=
inner_gfn_outs
[
offset
:]
...
...
@@ -1422,9 +1423,9 @@ class Scan(PureOp):
info
[
'n_mit_sot'
]
=
0
info
[
'tap_array'
]
=
tap_array
info
[
'gpu'
]
=
False
n_mit_mot
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
n_mit_mot
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
...
...
@@ -1443,55 +1444,55 @@ class Scan(PureOp):
n_mit_sot
=
0
n_sit_sot
=
0
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
scan_inputs
=
(
[
do_steps
]
+
scan_inputs
=
(
[
do_steps
]
+
scan_seqs
+
scan_mit_mot
+
scan_sitsot_init
+
old_scan_init
+
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)
]
+
args
[
offset
:]
)
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)
]
+
args
[
offset
:])
offset
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
offset
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
inner_other_args
=
self_inputs
[
offset
:]
inner_gfn_ins
=
(
inner_seqs
+
inner_gfn_ins
=
(
inner_seqs
+
inner_mit_mot
+
scan_sitsot_ins
+
old_scan_shared_ins
+
inner_other_args
)
inner_gfn_outs
=
(
scan_mit_mot_outs
+
inner_other_args
)
inner_gfn_outs
=
(
scan_mit_mot_outs
+
scan_sitsot_outs
+
scan_nit_sot_outs
+
old_scan_shared_outs
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
old_scan_shared_outs
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
# Re-order the gradients correctly
gradients
=
[
None
]
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
n_sitsot_outs
)
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[
offset
:
offset
+
self
.
n_seqs
]]
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
n_sitsot_outs
)
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[
offset
:
offset
+
self
.
n_seqs
]]
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_shared_outs
)]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_nit_sot
)
]
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_shared_outs
)]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_nit_sot
)
]
begin
=
end
end
=
begin
+
n_sitsot_outs
...
...
@@ -1501,11 +1502,12 @@ class Scan(PureOp):
def
R_op
(
self
,
inputs
,
eval_points
):
# Step 0. Don't work on the orignal tensor variables
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
,
'_rop'
)
self
.
outputs
,
'_rop'
)
self_inputs
=
rval
[
0
]
self_outputs
=
rval
[
1
]
# Step 1. Compute the R_op of the inner function
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
for
x
in
self_inputs
]
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
for
x
in
self_inputs
]
if
self
.
as_while
:
rop_self_outputs
=
self_outputs
[:
-
1
]
else
:
...
...
@@ -1524,33 +1526,33 @@ class Scan(PureOp):
# evan point for the number of nit_sot which I think should just be
# ignored (?)
info
=
{}
info
[
'n_seqs'
]
=
self
.
n_seqs
*
2
info
[
'n_mit_sot'
]
=
self
.
n_mit_sot
*
2
info
[
'n_sit_sot'
]
=
self
.
n_sit_sot
*
2
info
[
'n_mit_mot'
]
=
self
.
n_mit_mot
*
2
info
[
'n_nit_sot'
]
=
self
.
n_nit_sot
*
2
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
*
2
info
[
'n_seqs'
]
=
self
.
n_seqs
*
2
info
[
'n_mit_sot'
]
=
self
.
n_mit_sot
*
2
info
[
'n_sit_sot'
]
=
self
.
n_sit_sot
*
2
info
[
'n_mit_mot'
]
=
self
.
n_mit_mot
*
2
info
[
'n_nit_sot'
]
=
self
.
n_nit_sot
*
2
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
*
2
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
self
.
as_while
info
[
'profile'
]
=
self
.
profile
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
if
self
.
name
:
info
[
'name'
]
=
'rop_of_'
+
self
.
name
info
[
'name'
]
=
'rop_of_'
+
self
.
name
else
:
info
[
'name'
]
=
None
info
[
'mode'
]
=
self
.
mode
info
[
'inplace'
]
=
False
info
[
'mit_mot_out_slices'
]
=
self
.
mit_mot_out_slices
*
2
info
[
'mit_mot_out_slices'
]
=
self
.
mit_mot_out_slices
*
2
new_tap_array
=
[]
b
=
0
e
=
self
.
n_mit_mot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
b
=
e
e
+=
self
.
n_mit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
b
=
e
e
+=
self
.
n_sit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
info
[
'tap_array'
]
=
new_tap_array
# Sequences ...
...
...
@@ -1575,7 +1577,8 @@ class Scan(PureOp):
e
=
e
+
self
.
n_mit_sot
ib
=
ie
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
self
.
tap_array
[
self
.
n_mit_mot
:
\
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
scan_mit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_mit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
...
...
@@ -1598,8 +1601,7 @@ class Scan(PureOp):
# NIT_SOT sequences
b
=
e
e
=
e
+
self
.
n_nit_sot
scan_nit_sot
=
inputs
[
b
:
e
]
*
2
scan_nit_sot
=
inputs
[
b
:
e
]
*
2
# All other arguments
scan_other
=
inputs
[
e
:]
+
eval_points
[
e
:]
...
...
@@ -1625,13 +1627,13 @@ class Scan(PureOp):
e
=
e
+
self
.
n_shared_outs
inner_out_shared
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
inner_ins
=
(
inner_seqs
+
inner_ins
=
(
inner_seqs
+
inner_mit_mot
+
inner_mit_sot
+
inner_sit_sot
+
inner_shared
+
inner_other
)
inner_outs
=
(
inner_out_mit_mot
+
inner_other
)
inner_outs
=
(
inner_out_mit_mot
+
inner_out_mit_sot
+
inner_out_sit_sot
+
inner_out_nit_sot
+
...
...
@@ -1639,7 +1641,7 @@ class Scan(PureOp):
if
self
.
as_while
:
inner_outs
+=
[
self_outputs
[
-
1
]]
scan_inputs
=
(
[
inputs
[
0
]]
+
scan_inputs
=
(
[
inputs
[
0
]]
+
scan_seqs
+
scan_mit_mot
+
scan_mit_sot
+
...
...
@@ -1648,26 +1650,26 @@ class Scan(PureOp):
scan_nit_sot
+
scan_other
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
# Select only the result of the R_op results
final_outs
=
[]
b
=
self
.
n_mit_mot
e
=
self
.
n_mit_mot
*
2
e
=
self
.
n_mit_mot
*
2
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_mit_sot
e
=
e
+
self
.
n_mit_sot
*
2
e
=
e
+
self
.
n_mit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_sit_sot
e
=
e
+
self
.
n_sit_sot
*
2
e
=
e
+
self
.
n_sit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_nit_sot
e
=
e
+
self
.
n_nit_sot
*
2
e
=
e
+
self
.
n_nit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_shared_outs
e
=
e
+
self
.
n_shared_outs
*
2
e
=
e
+
self
.
n_shared_outs
*
2
final_outs
+=
outputs
[
b
:
e
]
return
final_outs
...
...
@@ -1678,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time
,
apply_cimpl
,
message
,
outputs_size
,
other_time
):
# Scan overhead profile
if
any
([
isinstance
(
node
.
op
,
Scan
)
and
v
>
0
for
(
_
,
node
),
v
in
if
any
([
isinstance
(
node
.
op
,
Scan
)
and
v
>
0
for
(
_
,
node
),
v
in
apply_time
.
items
()]):
print
print
'Scan overhead:'
print
'<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(
%
scan op time)> <sub scan op time(
%
scan op time)> <node>'
print
(
'<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(
%
scan op time)> <sub scan '
'op time(
%
scan op time)> <node>'
)
total_super_scan_time
=
0
total_scan_fct_time
=
0
total_scan_op_time
=
0
for
(
_
,
node
),
v
in
apply_time
.
items
():
for
(
_
,
node
),
v
in
apply_time
.
items
():
if
isinstance
(
node
.
op
,
Scan
):
if
v
>
0
:
if
v
>
0
:
scan_fct_time
=
node
.
op
.
mode_instance
.
fn_time
scan_op_time
=
node
.
op
.
mode_instance
.
local_time
total_super_scan_time
+=
v
total_scan_fct_time
+=
scan_fct_time
total_scan_op_time
+=
scan_op_time
print
'
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
v
,
scan_fct_time
,
scan_op_time
,
scan_fct_time
/
v
*
100
,
scan_op_time
/
v
*
100
),
node
print
'
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
v
,
scan_fct_time
,
scan_op_time
,
scan_fct_time
/
v
*
100
,
scan_op_time
/
v
*
100
),
node
else
:
print
' The node took 0s, so we can not compute the overhead'
,
node
print
' total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
total_super_scan_time
,
total_scan_fct_time
,
total_scan_op_time
,
total_scan_fct_time
/
total_super_scan_time
*
100
,
total_scan_op_time
/
total_super_scan_time
*
100
)
print
(
' The node took 0s, so we can not '
'compute the overhead'
),
node
print
' total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
total_super_scan_time
,
total_scan_fct_time
,
total_scan_op_time
,
total_scan_fct_time
/
total_super_scan_time
*
100
,
total_scan_op_time
/
total_super_scan_time
*
100
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论