提交 d07219c7 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

PEP8 fixes

Not sure it makes the file any more readable, but at least I've tried.
上级 b873707b
......@@ -188,18 +188,21 @@ class Scan(PureOp):
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)'
'of variable %s (argument number %d)'
' has dtype %s and %d dimension(s), while the result of the'
' inner function for this output has dtype %s and %d '
'dimension(s). This could happen if the inner graph of '
' scan results in an upcast or downcast. Please make '
'sure that you use dtypes consistently')
' has dtype %s and %d dimension(s), while the result '
'of the inner function for this output has dtype %s '
'and %d dimension(s). This could happen if the inner '
'graph of scan results in an upcast or downcast. '
'Please make sure that you use dtypes consistently')
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and inputs correspond)
# TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs)
# if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + self.info["n_nit_sot"]
# else:
# assert len(inputs) == len(self.inputs) + 1 + self.info["n_nit_sot"]
#if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + \
# self.info["n_nit_sot"]
#else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
self.vector_seqs = [seq.ndim == 1 for seq in
......@@ -236,26 +239,27 @@ class Scan(PureOp):
self.mitmot_out_taps(),
self.outer_mitmot(inputs))):
for k in xrange(len(itaps)):
if (inner_mitmot[ipos+k].type.dtype !=
if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or
inner_mitmot[ipos+k].ndim != outer_mitmot.ndim - 1):
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
str(inner_mitmot[ipos+k]),
inner_mitmot[ipos+k].type.dtype))
' in scan nomenclature) ',
str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
str(inner_mitmot[ipos + k]),
inner_mitmot[ipos + k].type.dtype))
ipos += len(itaps)
for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos+k].type.dtype !=
outer_mitmot.type.dtype or
inner_mitmot_outs[opos+k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg2 %
if (inner_mitmot_outs[opos + k].type.dtype != \
outer_mitmot.type.dtype or
inner_mitmot_outs[opos + k].ndim != \
outer_mitmot.ndim - 1):
raise ValueError(err_msg2 %
(str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
inner_mitmot_outs[opos+k].type.dtype))
argoffset + idx,
outer_mitmot.type.dtype,
inner_mitmot_outs[opos + k].type.dtype))
opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot
......@@ -266,28 +270,28 @@ class Scan(PureOp):
self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs))):
for k in xrange(len(itaps)):
if (inner_mitsots[ipos+k].type.dtype !=
outer_mitsot.type.dtype or
inner_mitsots[ipos+k].ndim != outer_mitsot.ndim - 1):
if (inner_mitsots[ipos + k].type.dtype != \
outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitsot),
argoffset + idx,
outer_mitsot.type.dtype,
otuer_mitsot.type.ndim,
str(inner_mitsot[ipos+k]),
inner_mitsots[ipos+k].type.dtype,
inner_mitsots[ipos+k].type.ndim))
str(inner_mitsot[ipos + k]),
inner_mitsots[ipos + k].type.dtype,
inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg2 %
raise ValueError(err_msg2 %
(str(outer_mitsot),
argoffset + idx,
outer_mitsot.type.dtype,
outer_mitsot.type.ndim,
inner_mitsot_out.type.dtype,
inner_mitsot_out.type.ndim ))
inner_mitsot_out.type.ndim))
argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot
......@@ -308,13 +312,13 @@ class Scan(PureOp):
inner_sitsot.type.ndim))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 %
raise ValueError(err_msg2 %
(str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.dtype,
outer_sitsot.type.ndim,
inner_sitsot_out.type.dtype,
inner_sitsot_out.type.ndim ))
inner_sitsot_out.type.ndim))
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same
......@@ -352,7 +356,7 @@ class Scan(PureOp):
inner_nonseq.type.ndim != outer_nonseq.type.ndim):
raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s')%
' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input an int/uint that
......@@ -1120,7 +1124,7 @@ class Scan(PureOp):
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if self.as_while:
scan_outs = [(Shape_i(0)(o),)+x[1:]
scan_outs = [(Shape_i(0)(o),) + x[1:]
for o, x in izip(node.outputs, scan_outs)]
return scan_outs
......@@ -1147,79 +1151,80 @@ class Scan(PureOp):
in xrange(self.n_mit_mot)])
outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot]
offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [ len(self.tap_array[x]) for x
in xrange( self.n_mit_mot
, self.n_mit_mot+self.n_mit_sot)])
outs_mit_sot = self_inputs[offset:offset+n_ins_mit_sot]
offset += n_ins_mit_sot
outs_sit_sot = self_inputs[offset:offset+self.n_sit_sot]
offset += self.n_sit_sot
old_scan_shared_ins = self_inputs[offset:offset+self.n_shared_outs]
out_offset = ( self.n_mit_mot_outs
+ self.n_mit_sot
+ self.n_nit_sot
+ self.n_sit_sot )
offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot,
self.n_mit_mot + self.n_mit_sot)])
outs_mit_sot = self_inputs[offset:offset + n_ins_mit_sot]
offset += n_ins_mit_sot
outs_sit_sot = self_inputs[offset:offset + self.n_sit_sot]
offset += self.n_sit_sot
old_scan_shared_ins = self_inputs[offset:offset + self.n_shared_outs]
out_offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_nit_sot +
self.n_sit_sot)
# shared variables as well as the condition
old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = ( 1
+ self.n_seqs
+ self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot)
old_scan_init = args[arg_offset: arg_offset+self.n_shared_outs]
offset += self.n_shared_outs
other_args = self_inputs[offset:]
old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = (1 +
self.n_seqs +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot)
old_scan_init = args[arg_offset: arg_offset + self.n_shared_outs]
offset += self.n_shared_outs
other_args = self_inputs[offset:]
# 4. Collect (possibly) differentiable inputs
diff_inputs = ( seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
other_args )
diff_inputs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
other_args)
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
def compute_gradient(y, g_y):
gmp = gradient.grad_sources_inputs(
[(y,g_y)], diff_inputs, False )
return [gmp.get(p, None) for p in diff_inputs ]
[(y, g_y)], diff_inputs, False)
return [gmp.get(p, None) for p in diff_inputs]
# 6. clean the outputs (i.e. remove update rules)
end = ( self.n_mit_mot_outs
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot )
clean_outputs = self_outputs[:end]
end = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot)
clean_outputs = self_outputs[:end]
g_outs_no_shared = g_outs[:end]
# 7.1. empty lists to hold gradients
# List of slices from outputs (used to compute the gradients)
inner_g_outs = []
g_out_slices = []
inner_g_outs = []
g_out_slices = []
# List of outputs of the gradient function
inner_gfn_outs = []
inner_gfn_outs = []
# slices of the input
prev_inner_gfn_outs = []
zeros_like_diff_ins = []
pos = ( self.n_seqs + n_ins_mit_mot + n_ins_mit_sot +
prev_inner_gfn_outs = []
zeros_like_diff_ins = []
pos = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot)
offset = len(args) - len(other_args) - pos
# 7.2. generate variables to represent previous steps of g_outs
for idx,diff_in in enumerate(diff_inputs):
for idx, diff_in in enumerate(diff_inputs):
prev_gfn_out = safe_new(diff_in)
if hasattr(diff_in,'name') and diff_in.name:
prev_gfn_out.name = 'g_prev_'+diff_in.name
if hasattr(diff_in, 'name') and diff_in.name:
prev_gfn_out.name = 'g_prev_' + diff_in.name
else:
prev_gfn_out.name = 'g_prev_'+str(idx)
prev_inner_gfn_outs.append( prev_gfn_out)
prev_gfn_out.name = 'g_prev_' + str(idx)
prev_inner_gfn_outs.append(prev_gfn_out)
if idx < pos:
zeros_like_diff_ins.append(tensor.zeros_like(diff_in))
else:
zeros_like_diff_ins.append(tensor.zeros_like(args[idx+offset]))
zeros_like_diff_ins.append(
tensor.zeros_like(args[idx + offset]))
# 7.3. compute gradients of the inputs given one output
for dx, out in enumerate(clean_outputs):
......@@ -1227,32 +1232,30 @@ class Scan(PureOp):
###
#### I need to clip the gradient HERE !!
if g_outs_no_shared[dx]:
g_out_slices.append(g_outs_no_shared[dx][0])
else:
g_out_slices.append(None)
if getattr(out,'name',None) is not None:
inner_g_out.name = 'g_'+out.name
if getattr(out, 'name', None) is not None:
inner_g_out.name = 'g_' + out.name
else:
inner_g_out.name = 'g_'+str(dx)
inner_g_out.name = 'g_' + str(dx)
inner_g_outs.append(inner_g_out)
_g_out = inner_g_out
grad_outs = compute_gradient(out, _g_out)
if not inner_gfn_outs:
for idx, gfn_out in enumerate(grad_outs):
if idx >= self.n_seqs:
inner_gfn_outs.append( prev_inner_gfn_outs[idx] )
inner_gfn_outs.append(prev_inner_gfn_outs[idx])
else:
inner_gfn_outs.append( None )
inner_gfn_outs.append(None)
# 7.4 Sum the gradients
# safety check, some of these inputs might still not be
# differentiable, for those we don't add them to the mix
# (assume their gradient is 0)
for i,(x,y) in enumerate(zip(grad_outs, inner_gfn_outs)):
for i, (x, y) in enumerate(zip(grad_outs, inner_gfn_outs)):
if x and y:
inner_gfn_outs[i] = x+y
inner_gfn_outs[i] = x + y
elif y:
inner_gfn_outs[i] = y
else:
......@@ -1276,28 +1279,27 @@ class Scan(PureOp):
g_outs[i] = theano.tensor.constant(
numpy.array(0, theano.config.floatX))
## 10. Get your sequence in order for the scan:
n_seqs = ( self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_nit_sot )
offset = ( self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot )
inner_seqs = ( seqs +
n_seqs = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_nit_sot)
offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot)
inner_seqs = (seqs +
outs_mit_mot +
outs_mit_sot +
outs_sit_sot +
inner_g_outs[offset:offset+self.n_nit_sot])
inner_g_outs[offset:offset + self.n_nit_sot])
scan_seqs = [ x[::-1] for x in args[1:self.n_seqs + 1]]
scan_seqs = [x[::-1] for x in args[1:self.n_seqs + 1]]
offset = 0
for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx])
seq = scan_outputs[offset+idx]
seq = scan_outputs[offset + idx]
for k in self.tap_array[idx]:
# We cut the sequence such that seq[i] corresponds to
# seq[i-k]
......@@ -1307,205 +1309,205 @@ class Scan(PureOp):
dim_offset = 0
if maxtap == mintap and maxtap != 0:
nw_seq = seq[:abs(maxtap)]
elif maxtap -k != 0 :
tmp = seq[dim_offset + k - mintap - 1:-(maxtap -k + 1)]
nw_seq = tmp[::-1]
elif maxtap - k != 0:
nw_seq = seq[dim_offset + k - mintap - 1:\
-(maxtap - k + 1)][::-1]
else:
nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1]
if getattr(seq,'name', None) is not None:
nw_seq.name = seq.name + '[%d:]'%k
if getattr(seq, 'name', None) is not None:
nw_seq.name = seq.name + '[%d:]' % k
scan_seqs.append(nw_seq)
offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot):
seq = scan_outputs[offset+idx][:-1]
seq = scan_outputs[offset + idx][:-1]
scan_seqs.append(seq[::-1])
offset = ( self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot )
scan_seqs += [ x[::-1] for x in
g_outs[offset:offset+self.n_nit_sot]]
offset = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot)
scan_seqs += [x[::-1] for x in
g_outs[offset:offset + self.n_nit_sot]]
scan_mit_mot = []
inner_mit_mot = []
scan_mit_mot_outs = []
mit_mot_taps = []
scan_mit_mot = []
inner_mit_mot = []
scan_mit_mot_outs = []
mit_mot_taps = []
mit_mot_out_slices = []
out_pos = 0
ins_pos = n_seqs
n_mit_mot_outs = 0
n_mit_mot_ins = 0
ins_pos = self.n_seqs
out_pos = 0
ins_pos = n_seqs
n_mit_mot_outs = 0
n_mit_mot_ins = 0
ins_pos = self.n_seqs
for idx in xrange(self.n_mit_mot):
scan_mit_mot.append( g_outs[idx][::-1] )
scan_mit_mot.append(g_outs[idx][::-1])
mit_mot_taps.append([])
mit_mot_out_slices.append([])
for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_mit_mot.append( inner_g_outs[out_pos] )
mit_mot_taps[idx].append(
inner_mit_mot.append(inner_g_outs[out_pos])
mit_mot_taps[idx].append(\
-self.mit_mot_out_slices[idx][jdx])
n_mit_mot_ins += 1
out_pos += 1
out_pos += 1
for jdx in xrange(len(self.tap_array[idx])):
inner_mit_mot.append( prev_inner_gfn_outs[ins_pos] )
scan_mit_mot_outs.append(
inner_gfn_outs[ ins_pos] )
n_mit_mot_ins += 1
ins_pos += 1
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos])
scan_mit_mot_outs.append(\
inner_gfn_outs[ins_pos])
n_mit_mot_ins += 1
ins_pos += 1
n_mit_mot_outs += 1
mit_mot_taps[idx].append( -self.tap_array[idx][jdx])
mit_mot_out_slices[idx].append(
-self.tap_array[idx][jdx] )
mit_mot_taps[idx].append(-self.tap_array[idx][jdx])
mit_mot_out_slices[idx].append(\
-self.tap_array[idx][jdx])
offset = self.n_mit_mot
for idx in xrange(self.n_mit_sot):
mit_mot_taps.append([])
mit_mot_out_slices.append([])
scan_mit_mot.append( g_outs[idx + offset][::-1] )
scan_mit_mot.append(g_outs[idx + offset][::-1])
idx_tap = idx + self.n_mit_mot
for jdx in xrange(len(self.tap_array[idx_tap])):
inner_mit_mot.append( prev_inner_gfn_outs[ins_pos] )
mit_mot_taps[idx+offset].append(
-self.tap_array[idx_tap][jdx] )
mit_mot_out_slices[idx].append(
-self.tap_array[idx_tap][jdx] )
scan_mit_mot_outs.append(inner_gfn_outs[ ins_pos] )
n_mit_mot_ins += 1
ins_pos += 1
inner_mit_mot.append(prev_inner_gfn_outs[ins_pos])
mit_mot_taps[idx + offset].append(\
-self.tap_array[idx_tap][jdx])
mit_mot_out_slices[idx].append(\
-self.tap_array[idx_tap][jdx])
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos])
n_mit_mot_ins += 1
ins_pos += 1
n_mit_mot_outs += 1
inner_mit_mot.append( inner_g_outs[out_pos] )
inner_mit_mot.append(inner_g_outs[out_pos])
out_pos += 1
n_mit_mot_ins += 1
mit_mot_taps[idx+offset].append( 0 )
mit_mot_taps[idx + offset].append(0)
offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot):
mit_mot_taps.append([0,1])
mit_mot_taps.append([0, 1])
mit_mot_out_slices.append([1])
scan_mit_mot.append( g_outs[idx + offset][::-1] )
scan_mit_mot_outs.append(inner_gfn_outs[ ins_pos ])
inner_mit_mot += [ inner_g_outs[out_pos]
, prev_inner_gfn_outs[ins_pos] ]
scan_mit_mot.append(g_outs[idx + offset][::-1])
scan_mit_mot_outs.append(inner_gfn_outs[ins_pos])
inner_mit_mot += [inner_g_outs[out_pos],
prev_inner_gfn_outs[ins_pos]]
n_mit_mot_outs += 1
out_pos += 1
ins_pos += 1
n_mit_mot_ins += 2
out_pos += 1
ins_pos += 1
n_mit_mot_ins += 2
n_nit_sot = self.n_seqs
scan_nit_sot_outs = inner_gfn_outs[:self.n_seqs]
if self.truncate_gradient != -1 :
if self.truncate_gradient != -1:
do_steps = tensor.minimum(args[0], self.truncate_gradient)
else:
do_steps = args[0]
offset = ( self.n_seqs
+ n_ins_mit_sot
+ n_ins_mit_mot
+ self.n_sit_sot )
offset = (self.n_seqs +
n_ins_mit_sot +
n_ins_mit_mot +
self.n_sit_sot)
# Instead of shared outs use sit_sot
n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
scan_sitsot_ins = prev_inner_gfn_outs[offset:]
n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
scan_sitsot_ins = prev_inner_gfn_outs[offset:]
scan_sitsot_init = []
for x in zeros_like_diff_ins[offset:]:
shapes = [x.shape[i] for i in xrange(x.ndim)]
empty = tensor.zeros([do_steps +1]+shapes,
empty = tensor.zeros([do_steps + 1] + shapes,
dtype=x.dtype)
scan_sitsot_init.append(empty)
scan_sitsot_outs = inner_gfn_outs[offset:]
tap_array = mit_mot_taps + [[-1] for k in
tap_array = mit_mot_taps + [[-1] for k in
xrange(n_sitsot_outs)]
info = {}
info['n_seqs'] = n_seqs
info['n_mit_sot'] = 0
info['tap_array'] = tap_array
info['gpu'] = False
n_mit_mot = ( self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot )
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = self.n_shared_outs
info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while
info['profile'] = self.profile
info['n_seqs'] = n_seqs
info['n_mit_sot'] = 0
info['tap_array'] = tap_array
info['gpu'] = False
n_mit_mot = (self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot)
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = self.n_shared_outs
info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while
info['profile'] = self.profile
if self.name:
info['name'] = 'grad_of_' + self.name
info['name'] = 'grad_of_' + self.name
else:
info['name'] = None
info['mode'] = self.mode
info['inplace'] = False
n_mit_sot = 0
n_sit_sot = 0
offset = ( 1
+ self.n_seqs
+ self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs )
scan_inputs = ( [do_steps] +
scan_seqs +
scan_mit_mot +
scan_sitsot_init +
old_scan_init +
[ args[0] for x in xrange(n_nit_sot) ] +
args[offset:] )
offset = ( self.n_seqs
+ n_ins_mit_mot
+ n_ins_mit_sot
+ self.n_sit_sot
+ self.n_shared_outs )
info['mode'] = self.mode
info['inplace'] = False
n_mit_sot = 0
n_sit_sot = 0
offset = (1 +
self.n_seqs +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot +
self.n_shared_outs)
scan_inputs = ([do_steps] +
scan_seqs +
scan_mit_mot +
scan_sitsot_init +
old_scan_init +
[args[0] for x in xrange(n_nit_sot)] +
args[offset:])
offset = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot +
self.n_shared_outs)
inner_other_args = self_inputs[offset:]
inner_gfn_ins = ( inner_seqs +
inner_mit_mot +
scan_sitsot_ins +
old_scan_shared_ins +
inner_other_args )
inner_gfn_outs = ( scan_mit_mot_outs +
scan_sitsot_outs +
scan_nit_sot_outs +
old_scan_shared_outs )
local_op = Scan( inner_gfn_ins, inner_gfn_outs, info )
inner_gfn_ins = (inner_seqs +
inner_mit_mot +
scan_sitsot_ins +
old_scan_shared_ins +
inner_other_args)
inner_gfn_outs = (scan_mit_mot_outs +
scan_sitsot_outs +
scan_nit_sot_outs +
old_scan_shared_outs)
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple):
outputs = [ outputs ]
outputs = [outputs]
# Re-order the gradients correctly
gradients = [None]
offset = ( self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot
+ n_sitsot_outs)
gradients += [ x[::-1] for x in outputs[offset:offset+self.n_seqs]]
offset = (self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot +
n_sitsot_outs)
gradients += [x[::-1] for x in outputs[offset:offset + self.n_seqs]]
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
gradients += [ x[::-1] for x in outputs[:end]]
gradients += [ None for x in xrange(self.n_shared_outs)]
gradients += [ None for x in xrange(self.n_nit_sot) ]
gradients += [x[::-1] for x in outputs[:end]]
gradients += [None for x in xrange(self.n_shared_outs)]
gradients += [None for x in xrange(self.n_nit_sot)]
begin = end
end = begin + n_sitsot_outs
end = begin + n_sitsot_outs
gradients += [x[-1] for x in outputs[begin:end]]
return gradients
def R_op(self, inputs, eval_points):
# Step 0. Don't work on the original tensor variables
rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs,'_rop')
self_inputs = rval[0]
self.outputs, '_rop')
self_inputs = rval[0]
self_outputs = rval[1]
# Step 1. Compute the R_op of the inner function
inner_eval_points = [scan_utils.safe_new(x,'_evalpoint') for x in self_inputs]
inner_eval_points = [scan_utils.safe_new(x, '_evalpoint')
for x in self_inputs]
if self.as_while:
rop_self_outputs = self_outputs[:-1]
else:
......@@ -1524,82 +1526,82 @@ class Scan(PureOp):
# eval point for the number of nit_sot which I think should just be
# ignored (?)
info = {}
info['n_seqs'] = self.n_seqs*2
info['n_mit_sot'] = self.n_mit_sot*2
info['n_sit_sot'] = self.n_sit_sot*2
info['n_mit_mot'] = self.n_mit_mot*2
info['n_nit_sot'] = self.n_nit_sot*2
info['n_shared_outs'] = self.n_shared_outs*2
info['gpu'] = False
info['as_while'] = self.as_while
info['n_seqs'] = self.n_seqs * 2
info['n_mit_sot'] = self.n_mit_sot * 2
info['n_sit_sot'] = self.n_sit_sot * 2
info['n_mit_mot'] = self.n_mit_mot * 2
info['n_nit_sot'] = self.n_nit_sot * 2
info['n_shared_outs'] = self.n_shared_outs * 2
info['gpu'] = False
info['as_while'] = self.as_while
info['profile'] = self.profile
info['truncate_gradient'] = self.truncate_gradient
if self.name:
info['name'] = 'rop_of_'+self.name
info['name'] = 'rop_of_' + self.name
else:
info['name'] = None
info['mode'] = self.mode
info['inplace'] = False
info['mit_mot_out_slices'] = self.mit_mot_out_slices*2
info['mit_mot_out_slices'] = self.mit_mot_out_slices * 2
new_tap_array = []
b = 0
e = self.n_mit_mot
new_tap_array += self.tap_array[b:e]*2
new_tap_array += self.tap_array[b:e] * 2
b = e
e += self.n_mit_sot
new_tap_array += self.tap_array[b:e]*2
new_tap_array += self.tap_array[b:e] * 2
b = e
e += self.n_sit_sot
new_tap_array += self.tap_array[b:e]*2
new_tap_array += self.tap_array[b:e] * 2
info['tap_array'] = new_tap_array
# Sequences ...
b = 1
b = 1
ib = 0
e = 1 + self.n_seqs
e = 1 + self.n_seqs
ie = self.n_seqs
scan_seqs = inputs[b:e] + eval_points[b:e]
scan_seqs = inputs[b:e] + eval_points[b:e]
inner_seqs = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_MOT sequences ...
b = e
e = e + self.n_mit_mot
b = e
e = e + self.n_mit_mot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[:self.n_mit_mot]]))
scan_mit_mot = inputs[b:e] + eval_points[b:e]
scan_mit_mot = inputs[b:e] + eval_points[b:e]
inner_mit_mot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_SOT sequences ...
b = e
e = e + self.n_mit_sot
b = e
e = e + self.n_mit_sot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:self.n_mit_mot+self.n_mit_sot]]))
scan_mit_sot = inputs[b:e] + eval_points[b:e]
self.tap_array[self.n_mit_mot:\
self.n_mit_mot + self.n_mit_sot]]))
scan_mit_sot = inputs[b:e] + eval_points[b:e]
inner_mit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#SIT_SOT sequences ...
b = e
e = e + self.n_sit_sot
b = e
e = e + self.n_sit_sot
ib = ie
ie = ie + self.n_sit_sot
scan_sit_sot = inputs[b:e] + eval_points[b:e]
scan_sit_sot = inputs[b:e] + eval_points[b:e]
inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#Shared outs ...
b = e
e = e + self.n_shared_outs
b = e
e = e + self.n_shared_outs
ib = ie
ie = ie + self.n_shared_outs
scan_shared = inputs[b:e] + eval_points[b:e]
scan_shared = inputs[b:e] + eval_points[b:e]
inner_shared = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# NIT_SOT sequences
b = e
e = e + self.n_nit_sot
scan_nit_sot = inputs[b:e]*2
scan_nit_sot = inputs[b:e] * 2
# All other arguments
scan_other = inputs[e:] + eval_points[e:]
......@@ -1625,13 +1627,13 @@ class Scan(PureOp):
e = e + self.n_shared_outs
inner_out_shared = self_outputs[b:e] + rop_outs[b:e]
inner_ins = ( inner_seqs +
inner_ins = (inner_seqs +
inner_mit_mot +
inner_mit_sot +
inner_sit_sot +
inner_shared +
inner_other )
inner_outs = ( inner_out_mit_mot +
inner_other)
inner_outs = (inner_out_mit_mot +
inner_out_mit_sot +
inner_out_sit_sot +
inner_out_nit_sot +
......@@ -1639,35 +1641,35 @@ class Scan(PureOp):
if self.as_while:
inner_outs += [self_outputs[-1]]
scan_inputs = ( [inputs[0]] +
scan_seqs +
scan_mit_mot +
scan_mit_sot +
scan_sit_sot +
scan_shared +
scan_nit_sot +
scan_other)
local_op = Scan( inner_ins, inner_outs, info )
scan_inputs = ([inputs[0]] +
scan_seqs +
scan_mit_mot +
scan_mit_sot +
scan_sit_sot +
scan_shared +
scan_nit_sot +
scan_other)
local_op = Scan(inner_ins, inner_outs, info)
outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple):
outputs = [ outputs ]
outputs = [outputs]
# Select only the result of the R_op results
final_outs = []
b = self.n_mit_mot
e = self.n_mit_mot*2
e = self.n_mit_mot * 2
final_outs += outputs[b:e]
b = e + self.n_mit_sot
e = e + self.n_mit_sot*2
e = e + self.n_mit_sot * 2
final_outs += outputs[b:e]
b = e + self.n_sit_sot
e = e + self.n_sit_sot*2
e = e + self.n_sit_sot * 2
final_outs += outputs[b:e]
b = e + self.n_nit_sot
e = e + self.n_nit_sot*2
e = e + self.n_nit_sot * 2
final_outs += outputs[b:e]
b = e + self.n_shared_outs
e = e + self.n_shared_outs*2
e = e + self.n_shared_outs * 2
final_outs += outputs[b:e]
return final_outs
......@@ -1678,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, apply_cimpl, message, outputs_size,
other_time):
# Scan overhead profile
if any([isinstance(node.op, Scan) and v>0 for (_,node),v in
if any([isinstance(node.op, Scan) and v > 0 for (_, node), v in
apply_time.items()]):
print
print 'Scan overhead:'
print '<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(% scan op time)> <sub scan op time(% scan op time)> <node>'
print ('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>')
total_super_scan_time = 0
total_scan_fct_time = 0
total_scan_op_time = 0
for (_,node),v in apply_time.items():
for (_, node), v in apply_time.items():
if isinstance(node.op, Scan):
if v> 0:
if v > 0:
scan_fct_time = node.op.mode_instance.fn_time
scan_op_time = node.op.mode_instance.local_time
total_super_scan_time += v
total_scan_fct_time += scan_fct_time
total_scan_op_time += scan_op_time
print ' %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%'%(
v, scan_fct_time, scan_op_time, scan_fct_time/v*100,
scan_op_time/v*100), node
print ' %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%' % (
v,
scan_fct_time,
scan_op_time,
scan_fct_time / v * 100,
scan_op_time / v * 100), node
else:
print ' The node took 0s, so we can not compute the overhead', node
print ' total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%'%(
total_super_scan_time, total_scan_fct_time, total_scan_op_time, total_scan_fct_time/total_super_scan_time*100, total_scan_op_time/total_super_scan_time*100)
print (' The node took 0s, so we can not '
'compute the overhead'), node
print ' total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%' % (
total_super_scan_time,
total_scan_fct_time,
total_scan_op_time,
total_scan_fct_time / total_super_scan_time * 100,
total_scan_op_time / total_super_scan_time * 100)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论