Commit c05d32ce, authored by Razvan Pascanu

Pre-emptive fix of a bug. There is no way to produce the bug with the trunk of Theano; it can only be triggered by Arnaud's optimization (which fuses scan ops together). Unfortunately, due to a lack of testing, that optimization has not yet been added to the trunk. The bug is that the scan op generated by the gradient uses the same variables as the original scan, so if you try to fuse the two ops you run into trouble.
Parent c82a180b
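The essence of the change is a clone-and-remap pattern: build fresh inner-graph variables, then rebuild the outputs on top of them so the gradient scan never aliases the original scan's variables. Below is a minimal standalone sketch of that pattern, assuming Theano's tensor module and the scan_utils.clone helper the commit itself calls (module paths follow later Theano releases and may differ in the tree this commit targets); the toy graph a, b, out is purely illustrative.

    import theano.tensor as tt
    from theano.scan_module import scan_utils

    def new_var(x):
        # Fresh variable of the same type: the gradient scan's inner
        # graph must not alias the original scan's inner variables.
        nw_x = x.type()
        if x.name:
            nw_x.name = x.name + 'grad_copy'
        return nw_x

    # Toy "inner graph": two inputs and one output built from them.
    a = tt.vector('a')
    b = tt.vector('b')
    out = a * b + a

    # Clone the inputs, then rebuild the output on top of the clones.
    inputs = [a, b]
    clones = [new_var(x) for x in inputs]
    givens = dict(zip(inputs, clones))
    cloned_out, = scan_utils.clone([out], replace=givens)

    # cloned_out computes the same expression but shares no variables
    # with out, so an optimization that fuses two scan ops stays safe.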
@@ -707,27 +707,44 @@ class Scan(Op):
         if not( type(scan_outputs) in (list,tuple)):
             scan_outputs = [scan_outputs]
         # 3. un-group / unzip the inputs
-        seqs = self.inputs[:self.n_seqs]
+        # Note! We don't want to use the very same variables as the ones
+        # used by the original scan; rather, create clones of them
+        def new_var(x):
+            nw_x = x.type()
+            if x.name:
+                nw_x.name = x.name + 'grad_copy'
+            return nw_x
+        self_inputs = [new_var(x) for x in self.inputs]
+        givens = {}
+        for new_x, x in zip(self_inputs, self.inputs):
+            givens[x] = new_x
+        self_outputs = scan_utils.clone(self.outputs, replace=givens)
+        seqs = self_inputs[:self.n_seqs]
         offset = self.n_seqs
         n_ins_mit_mot = numpy.sum([0] + [ len(self.tap_array[x]) for x
                                           in xrange(self.n_mit_mot) ])
-        outs_mit_mot = self.inputs[offset:offset+n_ins_mit_mot]
+        outs_mit_mot = self_inputs[offset:offset+n_ins_mit_mot]
         offset += n_ins_mit_mot
         n_ins_mit_sot = numpy.sum([0] + [ len(self.tap_array[x]) for x
                                           in xrange( self.n_mit_mot
                                                    , self.n_mit_mot+self.n_mit_sot)])
-        outs_mit_sot = self.inputs[offset:offset+n_ins_mit_sot]
+        outs_mit_sot = self_inputs[offset:offset+n_ins_mit_sot]
         offset += n_ins_mit_sot
-        outs_sit_sot = self.inputs[offset:offset+self.n_sit_sot]
+        outs_sit_sot = self_inputs[offset:offset+self.n_sit_sot]
         offset += self.n_sit_sot
-        old_scan_shared_ins = self.inputs[offset:offset+self.n_shared_outs]
+        old_scan_shared_ins = self_inputs[offset:offset+self.n_shared_outs]
         out_offset = ( self.n_mit_mot_outs
                       + self.n_mit_sot
                       + self.n_nit_sot
                       + self.n_sit_sot )
-        old_scan_shared_outs = self.outputs[out_offset:]
+        old_scan_shared_outs = self_outputs[out_offset:]
         arg_offset = ( 1
                      + self.n_seqs
                      + self.n_mit_mot
@@ -735,7 +752,7 @@ class Scan(Op):
                      + self.n_sit_sot)
         old_scan_init = args[arg_offset: arg_offset+self.n_shared_outs]
         offset += self.n_shared_outs
-        other_args = self.inputs[offset:]
+        other_args = self_inputs[offset:]
         # 4. Collect (possibly) differentiable inputs
@@ -758,7 +775,7 @@ class Scan(Op):
                + self.n_mit_sot
                + self.n_sit_sot
                + self.n_nit_sot )
-        clean_outputs = self.outputs[:end]
+        clean_outputs = self_outputs[:end]
         g_outs_no_shared = g_outs[:end]
         # 7.1. empty lists to hold gradients
@@ -1016,7 +1033,7 @@ class Scan(Op):
                  + self.n_sit_sot
                  + self.n_shared_outs )
-        inner_other_args = self.inputs[offset:]
+        inner_other_args = self_inputs[offset:]
         inner_gfn_ins = ( inner_seqs +
                          inner_mit_mot +
                          scan_shared_ins +