提交 2dc3ce06 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added R_op for scan

上级 a5d9a4b6
......@@ -1132,6 +1132,180 @@ class Scan(Op):
gradients += outputs[begin:end]
return gradients
def R_op(self, inputs, eval_points):
# Step 0. Don't work on the orignal tensor variables
rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs,'_rop')
self_inputs = rval[0]
self_outputs = rval[1]
# Step 1. Compute the R_op of the inner function
inner_eval_points = [scan_utils.safe_new(x,'_evalpoint') for x in self_inputs]
if self.as_while:
rop_self_outputs = self_outputs[:-1]
else:
rop_self_outputs = self_outputs
rop_outs = tensor.Rop(rop_self_outputs, self_inputs, inner_eval_points)
if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan
# When doing the R-op of scan, you end up having double of each type of
# input, because for each sequence you need also its eval point, for
# each mit_mot, mit_sot, sit_sot or other type of inputs the same.
# Interestingly enough, all these types of eval points behave the same
# way as the input to which they correspond
# The only exception is the eval point for the number of sequences, and
# evan point for the number of nit_sot which I think should just be
# ignored (?)
info = {}
info['n_seqs'] = self.n_seqs*2
info['n_mit_sot'] = self.n_mit_sot*2
info['n_sit_sot'] = self.n_sit_sot*2
info['n_mit_mot'] = self.n_mit_mot*2
info['n_nit_sot'] = self.n_nit_sot*2
info['n_shared_outs'] = self.n_shared_outs*2
info['gpu'] = False
info['as_while'] = self.as_while
info['profile'] = self.profile
info['truncate_gradient'] = self.truncate_gradient
if self.name:
info['name'] = 'rop_of_'+self.name
else:
info['name'] = None
info['mode'] = self.mode
info['inplace'] = False
info['mit_mot_out_slices'] = self.mit_mot_out_slices*2
new_tap_array = []
b = 0
e = self.n_mit_mot
new_tap_array += self.tap_array[b:e]*2
b = e
e += self.n_mit_sot
new_tap_array += self.tap_array[b:e]*2
b = e
e += self.n_sit_sot
new_tap_array += self.tap_array[b:e]*2
info['tap_array'] = new_tap_array
# Sequences ...
b = 1
ib = 0
e = 1 + self.n_seqs
ie = self.n_seqs
scan_seqs = inputs[b:e] + eval_points[b:e]
inner_seqs = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_MOT sequences ...
b = e
e = e + self.n_mit_mot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[:self.n_mit_mot]]))
scan_mit_mot = inputs[b:e] + eval_points[b:e]
inner_mit_mot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_SOT sequences ...
b = e
e = e + self.n_mit_sot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:self.n_mit_mot+self.n_mit_sot]]))
scan_mit_sot = inputs[b:e] + eval_points[b:e]
inner_mit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#SIT_SOT sequences ...
b = e
e = e + self.n_sit_sot
ib = ie
ie = ie + self.n_sit_sot
scan_sit_sot = inputs[b:e] + eval_points[b:e]
inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#Shared outs ...
b = e
e = e + self.n_shared_outs
ib = ie
ie = ie + self.n_shared_outs
scan_shared = inputs[b:e] + eval_points[b:e]
inner_shared = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# NIT_SOT sequences
b = e
e = e + self.n_nit_sot
scan_nit_sot = inputs[b:e]*2
# All other arguments
scan_other = inputs[e:] + eval_points[e:]
inner_other = self_inputs[ie:] + inner_eval_points[ie:]
# Outputs
n_mit_mot_outs = int(numpy.sum([len(x) for x in
self.mit_mot_out_slices]))
info['n_mit_mot_outs'] = n_mit_mot_outs
b = 0
e = n_mit_mot_outs
inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_mit_sot
inner_out_mit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_sit_sot
inner_out_sit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_nit_sot
inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_shared_outs
inner_out_shared = self_outputs[b:e] + rop_outs[b:e]
inner_ins = ( inner_seqs +
inner_mit_mot +
inner_mit_sot +
inner_sit_sot +
inner_shared +
inner_other )
inner_outs = ( inner_out_mit_mot +
inner_out_mit_sot +
inner_out_sit_sot +
inner_out_nit_sot +
inner_out_shared)
if self.as_while:
inner_outs += [self_outputs[-1]]
scan_inputs = ( [inputs[0]] +
scan_seqs +
scan_mit_mot +
scan_mit_sot +
scan_sit_sot +
scan_shared +
scan_nit_sot +
scan_other)
local_op = Scan( inner_ins, inner_outs, info )
outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple):
outputs = [ outputs ]
# Select only the result of the R_op results
final_outs = []
b = self.n_mit_mot
e = self.n_mit_mot*2
final_outs += outputs[b:e]
b = e + self.n_mit_sot
e = e + self.n_mit_sot*2
final_outs += outputs[b:e]
b = e + self.n_sit_sot
e = e + self.n_sit_sot*2
final_outs += outputs[b:e]
b = e + self.n_nit_sot
e = e + self.n_nit_sot*2
final_outs += outputs[b:e]
b = e + self.n_shared_outs
e = e + self.n_shared_outs*2
final_outs += outputs[b:e]
return final_outs
@theano.compile.profilemode.register_profiler_printer
def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论