提交 b392cc03 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

replace grad to be sit_sot sequences

The main bug was gradients where represented as shared variables. Now we represent them as sit_sot sequences to which only the last step is used (hence the savemem optimization does the memory clean up). The advantage is that gradients with respect to sitsot are well defined.
上级 7c19aaaf
......@@ -1405,11 +1405,18 @@ class Scan(PureOp):
+ n_ins_mit_sot
+ n_ins_mit_mot
+ self.n_sit_sot )
n_shared_outs = len(prev_inner_gfn_outs[offset:])
scan_shared_ins = prev_inner_gfn_outs[offset:]
scan_shared_init = zeros_like_diff_ins[offset:]
scan_shared_outs = inner_gfn_outs[offset:]
tap_array = mit_mot_taps
# Instead of shared outs use sit_sot
n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
scan_sitsot_ins = prev_inner_gfn_outs[offset:]
scan_sitsot_init = []
for x in zeros_like_diff_ins[offset:]:
shapes = [x.shape[i] for i in xrange(x.ndim)]
empty = tensor.zeros([do_steps +1]+shapes,
dtype=x.dtype)
scan_sitsot_init.append(empty)
scan_sitsot_outs = inner_gfn_outs[offset:]
tap_array = mit_mot_taps + [[-1] for k in
xrange(n_sitsot_outs)]
info = {}
info['n_seqs'] = n_seqs
info['n_mit_sot'] = 0
......@@ -1422,8 +1429,8 @@ class Scan(PureOp):
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = 0
info['n_shared_outs'] = n_shared_outs + self.n_shared_outs
info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = self.n_shared_outs
info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while
info['profile'] = self.profile
......@@ -1447,7 +1454,7 @@ class Scan(PureOp):
scan_inputs = ( [do_steps] +
scan_seqs +
scan_mit_mot +
scan_shared_init +
scan_sitsot_init +
old_scan_init +
[ args[0] for x in xrange(n_nit_sot) ] +
args[offset:] )
......@@ -1461,14 +1468,13 @@ class Scan(PureOp):
inner_other_args = self_inputs[offset:]
inner_gfn_ins = ( inner_seqs +
inner_mit_mot +
scan_shared_ins +
scan_sitsot_ins +
old_scan_shared_ins +
inner_other_args )
inner_gfn_outs = ( scan_mit_mot_outs +
scan_sitsot_outs +
scan_nit_sot_outs +
scan_shared_outs +
old_scan_shared_outs )
local_op = Scan( inner_gfn_ins, inner_gfn_outs, info )
outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple):
......@@ -1478,17 +1484,18 @@ class Scan(PureOp):
offset = ( self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot )
+ self.n_sit_sot
+ n_sitsot_outs)
gradients += [ x[::-1] for x in outputs[offset:offset+self.n_seqs]]
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
gradients += [ x[::-1] for x in outputs[:end]]
gradients += [ None for x in xrange(self.n_shared_outs)]
gradients += [ None for x in xrange(self.n_nit_sot) ]
begin = end + self.n_seqs
begin = end
end = begin + n_shared_outs
gradients += outputs[begin:end]
end = begin + n_sitsot_outs
gradients += [x[-1] for x in outputs[begin:end]]
return gradients
def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论