Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3a85ea97
提交
3a85ea97
authored
10月 11, 2012
作者:
Razvan Pascanu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
new implementation of grad method
上级
25d8089a
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
226 行增加
和
302 行删除
+226
-302
scan_op.py
theano/scan_module/scan_op.py
+226
-302
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
3a85ea97
...
...
@@ -34,7 +34,7 @@ from theano.gradient import DisconnectedType
from
theano.compile.profiling
import
ScanProfileStats
import
scan_utils
from
scan_utils
import
safe_new
from
scan_utils
import
safe_new
,
forced_replace
# Logging function for sending warning or info
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan_op'
)
...
...
@@ -1194,198 +1194,138 @@ class Scan(PureOp):
for
o
,
x
in
izip
(
node
.
outputs
,
scan_outs
)]
return
scan_outs
### GRAD FUNCTION
def
grad
(
self
,
args
,
g_outs
):
# This discards information about whether incoming gradients are 0
# or disconnected from the cost
# TODO: upgrade scan op to report disconnection correctly
def
strip_disconnected
(
g
):
if
isinstance
(
g
.
type
,
DisconnectedType
):
return
None
return
g
g_outs
=
[
strip_disconnected
(
g
)
for
g
in
g_outs
]
# 1. forward pass - get the outputs after applying scan
scan_outputs
=
self
(
*
args
)
# 2. make sure they are given as a list
if
not
(
type
(
scan_outputs
)
in
(
list
,
tuple
)):
scan_outputs
=
[
scan_outputs
]
# 3. un-group / unzip the inputs
# Note ! We don't want to use the actual same variable as the ones
# used by the original scan, rather create clones of them
def
get_input_pos
(
self
,
output_index
):
ipos
=
self
.
n_seqs
opos
=
output_index
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
if
len
(
otaps
)
>
opos
:
return
ipos
else
:
opos
=
opos
-
len
(
otaps
)
ipos
+=
len
(
itaps
)
for
dx
,
taps
in
enumerate
(
self
.
mitsot_taps
()):
if
opos
==
0
:
return
ipos
else
:
opos
=
opos
-
1
ipos
+=
len
(
taps
)
if
opos
<
self
.
info
[
'n_sit_sot'
]:
return
ipos
+
opos
else
:
return
-
1
def
get_output_pos
(
self
,
input_index
):
ipos
=
input_index
opos
=
0
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
if
len
(
itaps
)
>
ipos
:
return
opos
else
:
opos
+=
len
(
otaps
)
ipos
-=
len
(
itaps
)
for
dx
,
taps
in
enumerate
(
self
.
mitsot_taps
()):
if
len
(
taps
)
>
ipos
:
return
opos
else
:
opos
+=
1
ipos
-=
len
(
taps
)
if
ipos
<
self
.
info
[
'n_sit_sot'
]:
return
ipos
+
opos
else
:
return
-
1
def
get_output_slice_idx
(
self
,
output_index
):
ipos
=
0
opos
=
output_index
for
otaps
in
zip
(
self
.
mitmot_out_taps
()):
if
len
(
otaps
)
>
0
:
return
ipos
else
:
opos
=
opos
-
1
ipos
+=
len
(
otaps
)
return
ipos
+
opos
### GRAD FUNCTION
def
grad
(
self
,
inputs
,
dC_douts
):
outs
=
self
(
*
inputs
)
if
not
isinstance
(
outs
,
(
list
,
tuple
)):
outs
=
[
outs
]
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
,
'_grad'
)
self
.
outputs
)
self_inputs
=
rval
[
0
]
self_outputs
=
rval
[
1
]
#differentiable inputs
diff_inputs
=
(
self
.
inner_seqs
(
self_inputs
)
+
self
.
inner_mitmot
(
self_inputs
)
+
self
.
inner_mitsot
(
self_inputs
)
+
self
.
inner_sitsot
(
self_inputs
)
+
self
.
inner_non_seqs
(
self_inputs
))
diff_outputs
=
(
self
.
inner_mitmot_outs
(
self_outputs
)
+
self
.
inner_mitsot_outs
(
self_outputs
)
+
self
.
inner_sitsot_outs
(
self_outputs
)
+
self
.
inner_nitsot_outs
(
self_outputs
))
seqs
=
self_inputs
[:
self
.
n_seqs
]
offset
=
self
.
n_seqs
n_ins_mit_mot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
)])
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
+=
n_ins_mit_mot
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
,
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
offset
+=
n_ins_mit_sot
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
offset
+=
self
.
n_sit_sot
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
out_offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_nit_sot
+
self
.
n_sit_sot
)
# shared variables as well as the condition
old_scan_shared_outs
=
self_outputs
[
out_offset
:]
arg_offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
offset
+=
self
.
n_shared_outs
other_args
=
self_inputs
[
offset
:]
# 4. Collect (possibly) differentiable inputs
diff_inputs
=
(
seqs
+
outs_mit_mot
+
outs_mit_sot
+
outs_sit_sot
+
other_args
)
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
def
compute_gradient
(
y
,
g_y
):
gmp
=
gradient
.
grad_sources_inputs
(
[(
y
,
g_y
)],
diff_inputs
)
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
# 6. clean the outputs (i.e. remove update rules)
end
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
clean_outputs
=
self_outputs
[:
end
]
g_outs_no_shared
=
g_outs
[:
end
]
# 7.1. empty lists to hold gradients
# List of slices from outputs (used to compute the gradients)
inner_g_outs
=
[]
g_out_slices
=
[]
# List of outputs of the gradient function
inner_gfn_outs
=
[]
# slices of the input
prev_inner_gfn_outs
=
[]
zeros_like_diff_ins
=
[]
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
)
offset
=
len
(
args
)
-
len
(
other_args
)
-
pos
# 7.2. generate variables to represent previous steps of g_outs
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
prev_gfn_out
=
safe_new
(
diff_in
)
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
else
:
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
if
idx
<
pos
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
diff_in
))
else
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
# 7.3. compute gradients of the inputs given one output
for
dx
,
out
in
enumerate
(
clean_outputs
):
if
g_outs
[
dx
]
!=
None
:
inner_g_out
=
safe_new
(
g_outs
[
dx
][
0
])
else
:
# We do not have a gradient on this output so we need a
# placeholder, which for now has the same dtype as the
# output
inner_g_out
=
safe_new
(
out
)
###
#### I need to clip the gradient HERE !!
if
g_outs_no_shared
[
dx
]:
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
dXt_inps
=
[
None
for
inp
in
diff_inputs
]
dXtp1_dXts
=
[]
Xts
=
[]
for
idx
,
Xt
in
enumerate
(
diff_outputs
):
# We are looking for x[t-1] for a given x[t]
if
idx
>=
self
.
n_mit_mot_outs
:
Xt_placeholder
=
Xt
.
type
()
Xts
.
append
(
Xt_placeholder
)
Xtm1_pos
=
self
.
get_input_pos
(
idx
)
if
Xtm1_pos
>=
0
:
Xtm1
=
self_inputs
[
Xtm1_pos
]
# It is possible that X[t] is not actually a function of
# x[t-1], case in which we can not rely on this information
try
:
tmp
=
tensor
.
grad
(
Xt
.
sum
(),
Xtm1
)
except
ValueError
:
tmp
=
Xt
dXtp1_dXt
=
safe_new
(
tmp
)
else
:
g_out_slices
.
append
(
None
)
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
inner_g_out
.
name
=
'g_'
+
out
.
name
if
isinstance
(
dC_douts
[
idx
]
.
type
,
DisconnectedType
):
continue
dXtp1_dXt
=
safe_new
(
dC_douts
[
idx
][
0
])
dXtp1_dXts
.
append
(
dXtp1_dXt
)
_dXt_inps
=
compute_gradient
(
Xt
,
dXtp1_dXt
)
for
jdx
in
xrange
(
len
(
_dXt_inps
)):
if
dXt_inps
[
jdx
]
is
None
:
dXt_inps
[
jdx
]
=
_dXt_inps
[
jdx
]
elif
_dXt_inps
[
jdx
]:
dXt_inps
[
jdx
]
+=
_dXt_inps
[
jdx
]
# mask inputs that get no gradients
for
dx
in
xrange
(
len
(
dXt_inps
)):
if
not
dXt_inps
[
dx
]:
dXt_inps
[
dx
]
=
tensor
.
zeros_like
(
diff_inputs
[
dx
])
else
:
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_outs
.
append
(
inner_g_out
)
_g_out
=
inner_g_out
grad_outs
=
compute_gradient
(
out
,
_g_out
)
if
not
inner_gfn_outs
:
for
idx
,
gfn_out
in
enumerate
(
grad_outs
):
if
idx
>=
self
.
n_seqs
:
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
])
else
:
inner_gfn_outs
.
append
(
None
)
# 7.4 Sum the gradients
# safety check, some of this inputs might still not be
# differentiable, for those we don't add them to the mix
# (assume their gradient is 0)
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
if
x
and
y
:
inner_gfn_outs
[
i
]
=
x
+
y
elif
y
:
inner_gfn_outs
[
i
]
=
y
else
:
inner_gfn_outs
[
i
]
=
x
## 8. Mask the outputs that are not differentiable
# backwards pass
for
i
in
xrange
(
len
(
inner_gfn_outs
)):
if
inner_gfn_outs
[
i
]
is
None
:
inner_gfn_outs
[
i
]
=
tensor
.
zeros_like
(
diff_inputs
[
i
])
## 9. Mask the g_outs that are Nones :
for
i
,
out
in
enumerate
(
scan_outputs
):
if
g_outs
[
i
]
is
None
:
try
:
# this try is for catching non ndarray inputs (random
# states) it is more of a safety check ( all random
# states should be after n_outs_not_shared ...
g_outs
[
i
]
=
tensor
.
zeros_like
(
scan_outputs
[
i
])
except
Exception
:
g_outs
[
i
]
=
theano
.
tensor
.
constant
(
numpy
.
array
(
0
,
theano
.
config
.
floatX
))
## 10. Get your sequence in order for the scan:
n_seqs
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
inner_seqs
=
(
seqs
+
outs_mit_mot
+
outs_mit_sot
+
outs_sit_sot
+
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
offset
=
0
for
Xt
,
Xt_placeholder
in
zip
(
diff_outputs
[
self
.
n_mit_mot_outs
:],
Xts
):
tmp
=
forced_replace
(
dXt_inps
[
dx
],
Xt
,
Xt_placeholder
)
dXt_inps
[
dx
]
=
tmp
# construct dX_dtm1
dXt_dXtm1s
=
[
x
.
type
()
for
x
in
dXt_inps
[
self
.
n_seqs
:]]
for
dx
,
dXt_dXtm1
in
enumerate
(
dXt_dXtm1s
):
dXt_inps
[
dx
+
self
.
n_seqs
]
+=
dXt_dXtm1
# Construct scan op
# Seqs
outer_inp_seqs
=
[
x
[::
-
1
]
for
x
in
inputs
[
1
:
1
+
self
.
n_seqs
]]
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
):
mintap
=
numpy
.
min
(
self
.
tap_array
[
idx
])
maxtap
=
numpy
.
max
(
self
.
tap_array
[
idx
])
seq
=
scan_outputs
[
offset
+
idx
]
seq
=
outs
[
idx
]
for
k
in
self
.
tap_array
[
idx
]:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
if
maxtap
<
0
:
dim_offset
=
abs
(
maxtap
)
else
:
...
...
@@ -1397,124 +1337,125 @@ class Scan(PureOp):
-
(
maxtap
-
k
+
1
)][::
-
1
]
else
:
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
1
][::
-
1
]
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
scan_seqs
.
append
(
nw_seq
)
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
scan_seqs
.
append
(
seq
[::
-
1
])
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
scan_mit_mot
=
[]
inner_mit_mot
=
[]
scan_mit_mot_outs
=
[]
mit_mot_taps
=
[]
mit_mot_out_slices
=
[]
outer_inp_seqs
.
append
(
nw_seq
)
outer_inp_seqs
+=
[
x
[:
-
1
][::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
for
x
in
self
.
outer_nitsot_outs
(
dC_douts
):
if
not
isinstance
(
x
.
type
,
DisconnectedType
):
outer_inp_seqs
.
append
(
x
[::
-
1
])
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_mitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_nitsot_outs
(
outs
)]
inner_inp_seqs
=
self
.
inner_seqs
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_mitmot
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_mitsot
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_sitsot
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_nitsot_outs
(
dXtp1_dXts
)
inner_inp_seqs
+=
Xts
# mitmot
outer_inp_mitmot
=
[]
outer_out_mitmot
=
[]
inner_inp_mitmot
=
[]
inner_out_mitmot
=
[]
mitmot_inp_taps
=
[]
mitmot_out_taps
=
[]
out_pos
=
0
ins_pos
=
n_seqs
n_mit_mot_outs
=
0
n_mit_mot_ins
=
0
ins_pos
=
self
.
n_seqs
n_mitmot_outs
=
0
n_mitmot_inps
=
0
for
idx
in
xrange
(
self
.
n_mit_mot
):
scan_mit_mot
.
append
(
g_
outs
[
idx
][::
-
1
])
mit
_mot
_taps
.
append
([])
mit
_mot_out_slice
s
.
append
([])
outer_inp_mitmot
.
append
(
dC_d
outs
[
idx
][::
-
1
])
mit
mot_inp
_taps
.
append
([])
mit
mot_out_tap
s
.
append
([])
for
jdx
in
xrange
(
len
(
self
.
mit_mot_out_slices
[
idx
])):
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
])
mit_mot_taps
[
idx
]
.
append
(
\
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
n_mit_mot_ins
+=
1
inner_inp_mitmot
.
append
(
dXtp1_dXts
[
out_pos
])
mitmot_inp_taps
[
idx
]
.
append
(
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
n_mitmot_inps
+=
1
out_pos
+=
1
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
])
scan_mit_mot_outs
.
append
(
\
inner_gfn_outs
[
ins_pos
])
n_mit_mot_ins
+=
1
inner_inp_mitmot
.
append
(
dXt_dXtm1s
[
ins_pos
-
self
.
n_seqs
])
inner_out_mitmot
.
append
(
dXt_inps
[
ins_pos
])
n_mitmot_inps_
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx
][
jdx
])
n_mitmot_outs
+=
1
mitmot_inp_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mitmot_out_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
offset
=
self
.
n_mit_mot
for
idx
in
xrange
(
self
.
n_mit_sot
):
mit
_mot
_taps
.
append
([])
mit
_mot_out_slice
s
.
append
([])
scan_mit_mot
.
append
(
g_
outs
[
idx
+
offset
][::
-
1
])
mit
mot_inp
_taps
.
append
([])
mit
mot_out_tap
s
.
append
([])
outer_inp_mitmot
.
append
(
dC_d
outs
[
idx
+
offset
][::
-
1
])
idx_tap
=
idx
+
self
.
n_mit_mot
inner_inp_mitmot
.
append
(
dXtp1_dXts
[
out_pos
])
out_pos
+=
1
n_mitmot_inps
+=
1
mitmot_inp_taps
[
idx
+
offset
]
.
append
(
0
)
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx_tap
])):
inner_
mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_po
s
])
mit
_mot_taps
[
idx
+
offset
]
.
append
(
\
inner_
inp_mitmot
.
append
(
dXt_dXtm1s
[
ins_pos
-
self
.
n_seq
s
])
mit
mot_inp_taps
[
idx
+
offset
]
.
append
(
-
self
.
tap_array
[
idx_tap
][
jdx
])
mit
_mot_out_slices
[
idx
]
.
append
(
\
mit
mot_out_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx_tap
][
jdx
])
scan_mit_mot_outs
.
append
(
inner_gfn_out
s
[
ins_pos
])
n_mit
_mot_in
s
+=
1
inner_out_mitmot
.
append
(
dXt_inp
s
[
ins_pos
])
n_mit
mot_inp
s
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
])
out_pos
+=
1
n_mit_mot_ins
+=
1
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
n_mitmot_outs
+=
1
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_out_slices
.
append
([
1
])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
],
prev_inner_gfn_outs
[
ins_pos
]]
n_mit_mot_outs
+=
1
mitmot_inp_taps
.
append
([
0
,
1
])
mitmot_out_taps
.
append
([
1
])
if
not
isinstance
(
dC_douts
[
idx
+
offset
]
.
type
,
DisconnectedType
):
outer_inp_mitmot
.
append
(
dC_douts
[
idx
+
offset
][::
-
1
])
else
:
outer_inp_mitmot
.
append
(
tensor
.
zeros
(
outs
[
idx
+
offset
]
.
shape
,
dtype
=
dXt_inps
[
ins_pos
]
.
dtype
))
inner_out_mitmot
.
append
(
dXt_inps
[
ins_pos
])
inner_inp_mitmot
+=
[
dXtp1_dXts
[
out_pos
],
dXt_dXtm1s
[
ins_pos
-
self
.
n_seqs
]]
n_mitmot_outs
+=
1
out_pos
+=
1
ins_pos
+=
1
n_mit_mot_ins
+=
2
n_nit_sot
=
self
.
n_seqs
scan_nit_sot_outs
=
inner_gfn_outs
[:
self
.
n_seqs
]
n_mitmot_inps
+=
2
if
self
.
truncate_gradient
!=
-
1
:
do_steps
=
tensor
.
minimum
(
arg
s
[
0
],
self
.
truncate_gradient
)
do_steps
=
tensor
.
minimum
(
input
s
[
0
],
self
.
truncate_gradient
)
else
:
do_steps
=
args
[
0
]
offset
=
(
self
.
n_seqs
+
n_ins_mit_sot
+
n_ins_mit_mot
+
self
.
n_sit_sot
)
# Instead of shared outs use sit_sot
n_sitsot_outs
=
len
(
prev_inner_gfn_outs
[
offset
:])
scan_sitsot_ins
=
prev_inner_gfn_outs
[
offset
:]
scan_sitsot_init
=
[]
for
x
in
zeros_like_diff_ins
[
offset
:]:
shapes
=
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)]
empty
=
tensor
.
zeros
([
do_steps
+
1
]
+
shapes
,
dtype
=
x
.
dtype
)
scan_sitsot_init
.
append
(
empty
)
scan_sitsot_outs
=
inner_gfn_outs
[
offset
:]
tap_array
=
mit_mot_taps
+
[[
-
1
]
for
k
in
do_steps
=
inputs
[
0
]
n_nit_sot
=
self
.
n_seqs
inner_out_nitsot
=
dXt_inps
[:
self
.
n_seqs
]
inner_out_sitsot
=
dXt_inps
[
ins_pos
:]
inner_inp_sitsot
=
dXt_dXtm1s
[
ins_pos
-
self
.
n_seqs
:]
outer_inp_sitsot
=
[
tensor
.
zeros
([
do_steps
+
1
]
+
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)],
dtype
=
y
.
dtype
)
for
y
,
x
in
zip
(
inner_inp_sitsot
,
self
.
outer_non_seqs
(
inputs
))]
n_sitsot_outs
=
len
(
outer_inp_sitsot
)
new_tap_array
=
mitmot_inp_taps
+
[[
-
1
]
for
k
in
xrange
(
n_sitsot_outs
)]
info
=
{}
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_seqs'
]
=
len
(
outer_inp_seqs
)
info
[
'n_mit_sot'
]
=
0
info
[
'tap_array'
]
=
tap_array
info
[
'tap_array'
]
=
new_
tap_array
info
[
'gpu'
]
=
False
n_mit_mot
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'n_mit_mot'
]
=
len
(
outer_inp_mitmot
)
info
[
'n_mit_mot_outs'
]
=
n_mitmot_outs
info
[
'mit_mot_out_slices'
]
=
mitmot_out_taps
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'n_sit_sot'
]
=
n_sitsot_outs
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
info
[
'n_shared_outs'
]
=
0
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'as_while'
]
=
self
.
as_while
info
[
'profile'
]
=
self
.
profile
...
...
@@ -1524,47 +1465,30 @@ class Scan(PureOp):
else
:
info
[
'name'
]
=
None
info
[
'mode'
]
=
self
.
mode
n_mit_sot
=
0
n_sit_sot
=
0
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
scan_inputs
=
([
do_steps
]
+
scan_seqs
+
scan_mit_mot
+
scan_sitsot_init
+
old_scan_init
+
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)]
+
args
[
offset
:])
offset
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
outer_inputs
=
([
do_steps
]
+
outer_inp_seqs
+
outer_inp_mitmot
+
outer_inp_sitsot
+
[
inputs
[
0
]
for
x
in
xrange
(
n_nit_sot
)]
+
self
.
outer_shared
(
inputs
)
+
self
.
outer_non_seqs
(
inputs
))
inner_other_args
=
self_inputs
[
offset
:]
inner_gfn_ins
=
(
inner_seqs
+
inner_mit_mot
+
scan_sitsot_ins
+
old_scan_shared_ins
+
inner_other_args
)
inner_gfn_outs
=
(
scan_mit_mot_outs
+
scan_sitsot_outs
+
scan_nit_sot_outs
+
old_scan_shared_outs
)
inner_gfn_ins
=
(
inner_inp_seqs
+
inner_inp_mitmot
+
inner_inp_sitsot
+
self
.
inner_shared
(
self_inputs
)
+
self
.
inner_non_seqs
(
self_inputs
))
inner_gfn_outs
=
(
inner_out_mitmot
+
inner_out_sitsot
+
inner_out_nitsot
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
outputs
=
local_op
(
*
scan
_inputs
)
outputs
=
local_op
(
*
outer
_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
# Re-order the gradients correctly
gradients
=
[
grad_undefined
(
self
,
0
,
arg
s
[
0
],
'Number of steps'
)]
gradients
=
[
grad_undefined
(
self
,
0
,
input
s
[
0
],
'Number of steps'
)]
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
...
...
@@ -1576,12 +1500,12 @@ class Scan(PureOp):
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
start
=
len
(
gradients
)
gradients
+=
[
grad_undefined
(
self
,
x
+
start
,
arg
s
[
x
+
start
],
grad_undefined
(
self
,
x
+
start
,
input
s
[
x
+
start
],
'Shared Variable with update'
)
for
x
in
xrange
(
self
.
n_shared_outs
)]
start
=
len
(
gradients
)
gradients
+=
[
grad_undefined
(
self
,
x
+
start
,
arg
s
[
x
+
start
],
grad_undefined
(
self
,
x
+
start
,
input
s
[
x
+
start
],
'Dimension of memory buffer for output'
)
for
x
in
xrange
(
self
.
n_nit_sot
)]
begin
=
end
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论