提交 b392cc03 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Replace gradient shared variables with sit_sot sequences

The main bug was that gradients were represented as shared variables. Now we represent them as sit_sot sequences, of which only the last step is used (hence the savemem optimization does the memory clean-up). The advantage is that gradients with respect to sit_sot outputs are well defined.
上级 7c19aaaf
...@@ -1405,11 +1405,18 @@ class Scan(PureOp): ...@@ -1405,11 +1405,18 @@ class Scan(PureOp):
+ n_ins_mit_sot + n_ins_mit_sot
+ n_ins_mit_mot + n_ins_mit_mot
+ self.n_sit_sot ) + self.n_sit_sot )
n_shared_outs = len(prev_inner_gfn_outs[offset:]) # Instead of shared outs use sit_sot
scan_shared_ins = prev_inner_gfn_outs[offset:] n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
scan_shared_init = zeros_like_diff_ins[offset:] scan_sitsot_ins = prev_inner_gfn_outs[offset:]
scan_shared_outs = inner_gfn_outs[offset:] scan_sitsot_init = []
tap_array = mit_mot_taps for x in zeros_like_diff_ins[offset:]:
shapes = [x.shape[i] for i in xrange(x.ndim)]
empty = tensor.zeros([do_steps +1]+shapes,
dtype=x.dtype)
scan_sitsot_init.append(empty)
scan_sitsot_outs = inner_gfn_outs[offset:]
tap_array = mit_mot_taps + [[-1] for k in
xrange(n_sitsot_outs)]
info = {} info = {}
info['n_seqs'] = n_seqs info['n_seqs'] = n_seqs
info['n_mit_sot'] = 0 info['n_mit_sot'] = 0
...@@ -1422,8 +1429,8 @@ class Scan(PureOp): ...@@ -1422,8 +1429,8 @@ class Scan(PureOp):
info['n_mit_mot_outs'] = n_mit_mot_outs info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices info['mit_mot_out_slices'] = mit_mot_out_slices
info['truncate_gradient'] = self.truncate_gradient info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = 0 info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = n_shared_outs + self.n_shared_outs info['n_shared_outs'] = self.n_shared_outs
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while info['as_while'] = self.as_while
info['profile'] = self.profile info['profile'] = self.profile
...@@ -1447,7 +1454,7 @@ class Scan(PureOp): ...@@ -1447,7 +1454,7 @@ class Scan(PureOp):
scan_inputs = ( [do_steps] + scan_inputs = ( [do_steps] +
scan_seqs + scan_seqs +
scan_mit_mot + scan_mit_mot +
scan_shared_init + scan_sitsot_init +
old_scan_init + old_scan_init +
[ args[0] for x in xrange(n_nit_sot) ] + [ args[0] for x in xrange(n_nit_sot) ] +
args[offset:] ) args[offset:] )
...@@ -1461,14 +1468,13 @@ class Scan(PureOp): ...@@ -1461,14 +1468,13 @@ class Scan(PureOp):
inner_other_args = self_inputs[offset:] inner_other_args = self_inputs[offset:]
inner_gfn_ins = ( inner_seqs + inner_gfn_ins = ( inner_seqs +
inner_mit_mot + inner_mit_mot +
scan_shared_ins + scan_sitsot_ins +
old_scan_shared_ins + old_scan_shared_ins +
inner_other_args ) inner_other_args )
inner_gfn_outs = ( scan_mit_mot_outs + inner_gfn_outs = ( scan_mit_mot_outs +
scan_sitsot_outs +
scan_nit_sot_outs + scan_nit_sot_outs +
scan_shared_outs +
old_scan_shared_outs ) old_scan_shared_outs )
local_op = Scan( inner_gfn_ins, inner_gfn_outs, info ) local_op = Scan( inner_gfn_ins, inner_gfn_outs, info )
outputs = local_op(*scan_inputs) outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
...@@ -1478,17 +1484,18 @@ class Scan(PureOp): ...@@ -1478,17 +1484,18 @@ class Scan(PureOp):
offset = ( self.n_mit_mot offset = ( self.n_mit_mot
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot ) + self.n_sit_sot
+ n_sitsot_outs)
gradients += [ x[::-1] for x in outputs[offset:offset+self.n_seqs]] gradients += [ x[::-1] for x in outputs[offset:offset+self.n_seqs]]
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
gradients += [ x[::-1] for x in outputs[:end]] gradients += [ x[::-1] for x in outputs[:end]]
gradients += [ None for x in xrange(self.n_shared_outs)] gradients += [ None for x in xrange(self.n_shared_outs)]
gradients += [ None for x in xrange(self.n_nit_sot) ] gradients += [ None for x in xrange(self.n_nit_sot) ]
begin = end + self.n_seqs begin = end
end = begin + n_shared_outs end = begin + n_sitsot_outs
gradients += outputs[begin:end] gradients += [x[-1] for x in outputs[begin:end]]
return gradients return gradients
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论