提交 d07219c7 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

PEP8 fixes

Not sure it makes the file any more readable, but at least I've tried.
上级 b873707b
...@@ -188,18 +188,21 @@ class Scan(PureOp): ...@@ -188,18 +188,21 @@ class Scan(PureOp):
'following error has been encountered: The ' 'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)' 'initial state (outputs_info in scan nomenclature)'
'of variable %s (argument number %d)' 'of variable %s (argument number %d)'
' has dtype %s and %d dimension(s), while the result of the' ' has dtype %s and %d dimension(s), while the result '
' inner function for this output has dtype %s and %d ' 'of the inner function for this output has dtype %s '
'dimension(s). This could happen if the inner graph of ' 'and %d dimension(s). This could happen if the inner '
' scan results in an upcast or downcast. Please make ' 'graph of scan results in an upcast or downcast. '
'sure that you use dtypes consistently') 'Please make sure that you use dtypes consistently')
# TODO make the assert exact # TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and inputs correspond) # TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs) #assert len(inputs) >= len(self.inputs)
# if self.info['as_while']: #if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + self.info["n_nit_sot"] # assert len(inputs) == len(self.inputs) + 2 + \
# else: # self.info["n_nit_sot"]
# assert len(inputs) == len(self.inputs) + 1 + self.info["n_nit_sot"] #else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors # Flags that indicate which inputs are vectors
self.vector_seqs = [seq.ndim == 1 for seq in self.vector_seqs = [seq.ndim == 1 for seq in
...@@ -236,26 +239,27 @@ class Scan(PureOp): ...@@ -236,26 +239,27 @@ class Scan(PureOp):
self.mitmot_out_taps(), self.mitmot_out_taps(),
self.outer_mitmot(inputs))): self.outer_mitmot(inputs))):
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitmot[ipos+k].type.dtype != if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or outer_mitmot.type.dtype or
inner_mitmot[ipos+k].ndim != outer_mitmot.ndim - 1): inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitmot), str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.type.dtype, outer_mitmot.type.dtype,
str(inner_mitmot[ipos+k]), str(inner_mitmot[ipos + k]),
inner_mitmot[ipos+k].type.dtype)) inner_mitmot[ipos + k].type.dtype))
ipos += len(itaps) ipos += len(itaps)
for k in xrange(len(otaps)): for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos+k].type.dtype != if (inner_mitmot_outs[opos + k].type.dtype != \
outer_mitmot.type.dtype or outer_mitmot.type.dtype or
inner_mitmot_outs[opos+k].ndim != outer_mitmot.ndim - 1): inner_mitmot_outs[opos + k].ndim != \
raise ValueError(err_msg2 % outer_mitmot.ndim - 1):
raise ValueError(err_msg2 %
(str(outer_mitmot), (str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.type.dtype, outer_mitmot.type.dtype,
inner_mitmot_outs[opos+k].type.dtype)) inner_mitmot_outs[opos + k].type.dtype))
opos += len(otaps) opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs)) argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot # Same checks as above but for outputs of type mit_sot
...@@ -266,28 +270,28 @@ class Scan(PureOp): ...@@ -266,28 +270,28 @@ class Scan(PureOp):
self.outer_mitsot(inputs), self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs))): self.inner_mitsot_outs(self.outputs))):
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitsots[ipos+k].type.dtype != if (inner_mitsots[ipos + k].type.dtype != \
outer_mitsot.type.dtype or outer_mitsot.type.dtype or
inner_mitsots[ipos+k].ndim != outer_mitsot.ndim - 1): inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitsot), str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
otuer_mitsot.type.ndim, otuer_mitsot.type.ndim,
str(inner_mitsot[ipos+k]), str(inner_mitsot[ipos + k]),
inner_mitsots[ipos+k].type.dtype, inner_mitsots[ipos + k].type.dtype,
inner_mitsots[ipos+k].type.ndim)) inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1): inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_mitsot), (str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
outer_mitsot.type.ndim, outer_mitsot.type.ndim,
inner_mitsot_out.type.dtype, inner_mitsot_out.type.dtype,
inner_mitsot_out.type.ndim )) inner_mitsot_out.type.ndim))
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
...@@ -308,13 +312,13 @@ class Scan(PureOp): ...@@ -308,13 +312,13 @@ class Scan(PureOp):
inner_sitsot.type.ndim)) inner_sitsot.type.ndim))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1): inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_sitsot), (str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
outer_sitsot.type.ndim, outer_sitsot.type.ndim,
inner_sitsot_out.type.dtype, inner_sitsot_out.type.dtype,
inner_sitsot_out.type.ndim )) inner_sitsot_out.type.ndim))
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
...@@ -352,7 +356,7 @@ class Scan(PureOp): ...@@ -352,7 +356,7 @@ class Scan(PureOp):
inner_nonseq.type.ndim != outer_nonseq.type.ndim): inner_nonseq.type.ndim != outer_nonseq.type.ndim):
raise ValueError(('Argument %s given to scan node does not' raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s')% ' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq))) (str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs): for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
...@@ -1120,7 +1124,7 @@ class Scan(PureOp): ...@@ -1120,7 +1124,7 @@ class Scan(PureOp):
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if self.as_while: if self.as_while:
scan_outs = [(Shape_i(0)(o),)+x[1:] scan_outs = [(Shape_i(0)(o),) + x[1:]
for o, x in izip(node.outputs, scan_outs)] for o, x in izip(node.outputs, scan_outs)]
return scan_outs return scan_outs
...@@ -1147,79 +1151,80 @@ class Scan(PureOp): ...@@ -1147,79 +1151,80 @@ class Scan(PureOp):
in xrange(self.n_mit_mot)]) in xrange(self.n_mit_mot)])
outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot] outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot]
offset += n_ins_mit_mot offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [ len(self.tap_array[x]) for x n_ins_mit_sot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange( self.n_mit_mot in xrange(self.n_mit_mot,
, self.n_mit_mot+self.n_mit_sot)]) self.n_mit_mot + self.n_mit_sot)])
outs_mit_sot = self_inputs[offset:offset+n_ins_mit_sot] outs_mit_sot = self_inputs[offset:offset + n_ins_mit_sot]
offset += n_ins_mit_sot offset += n_ins_mit_sot
outs_sit_sot = self_inputs[offset:offset+self.n_sit_sot] outs_sit_sot = self_inputs[offset:offset + self.n_sit_sot]
offset += self.n_sit_sot offset += self.n_sit_sot
old_scan_shared_ins = self_inputs[offset:offset+self.n_shared_outs] old_scan_shared_ins = self_inputs[offset:offset + self.n_shared_outs]
out_offset = ( self.n_mit_mot_outs out_offset = (self.n_mit_mot_outs +
+ self.n_mit_sot self.n_mit_sot +
+ self.n_nit_sot self.n_nit_sot +
+ self.n_sit_sot ) self.n_sit_sot)
# shared variables as well as the condition # shared variables as well as the condition
old_scan_shared_outs = self_outputs[out_offset:] old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = ( 1 arg_offset = (1 +
+ self.n_seqs self.n_seqs +
+ self.n_mit_mot self.n_mit_mot +
+ self.n_mit_sot self.n_mit_sot +
+ self.n_sit_sot) self.n_sit_sot)
old_scan_init = args[arg_offset: arg_offset+self.n_shared_outs] old_scan_init = args[arg_offset: arg_offset + self.n_shared_outs]
offset += self.n_shared_outs offset += self.n_shared_outs
other_args = self_inputs[offset:] other_args = self_inputs[offset:]
# 4. Collect (possibly) differentiable inputs # 4. Collect (possibly) differentiable inputs
diff_inputs = ( seqs + diff_inputs = (seqs +
outs_mit_mot + outs_mit_mot +
outs_mit_sot + outs_mit_sot +
outs_sit_sot + outs_sit_sot +
other_args ) other_args)
#args[-len(other_args):] ) #args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over # 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs) # the gradients with respect to all outputs)
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
gmp = gradient.grad_sources_inputs( gmp = gradient.grad_sources_inputs(
[(y,g_y)], diff_inputs, False ) [(y, g_y)], diff_inputs, False)
return [gmp.get(p, None) for p in diff_inputs ] return [gmp.get(p, None) for p in diff_inputs]
# 6. clean the outputs (i.e. remove update rules) # 6. clean the outputs (i.e. remove update rules)
end = ( self.n_mit_mot_outs end = (self.n_mit_mot_outs +
+ self.n_mit_sot self.n_mit_sot +
+ self.n_sit_sot self.n_sit_sot +
+ self.n_nit_sot ) self.n_nit_sot)
clean_outputs = self_outputs[:end] clean_outputs = self_outputs[:end]
g_outs_no_shared = g_outs[:end] g_outs_no_shared = g_outs[:end]
# 7.1. empty lists to hold gradients # 7.1. empty lists to hold gradients
# List of slices from outputs (used to compute the gradients) # List of slices from outputs (used to compute the gradients)
inner_g_outs = [] inner_g_outs = []
g_out_slices = [] g_out_slices = []
# List of outputs of the gradient function # List of outputs of the gradient function
inner_gfn_outs = [] inner_gfn_outs = []
# slices of the input # slices of the input
prev_inner_gfn_outs = [] prev_inner_gfn_outs = []
zeros_like_diff_ins = [] zeros_like_diff_ins = []
pos = ( self.n_seqs + n_ins_mit_mot + n_ins_mit_sot + pos = (self.n_seqs +
n_ins_mit_mot +
n_ins_mit_sot +
self.n_sit_sot) self.n_sit_sot)
offset = len(args) - len(other_args) - pos offset = len(args) - len(other_args) - pos
# 7.2. generate variables to represent previous steps of g_outs # 7.2. generate variables to represent previous steps of g_outs
for idx,diff_in in enumerate(diff_inputs): for idx, diff_in in enumerate(diff_inputs):
prev_gfn_out = safe_new(diff_in) prev_gfn_out = safe_new(diff_in)
if hasattr(diff_in,'name') and diff_in.name: if hasattr(diff_in, 'name') and diff_in.name:
prev_gfn_out.name = 'g_prev_'+diff_in.name prev_gfn_out.name = 'g_prev_' + diff_in.name
else: else:
prev_gfn_out.name = 'g_prev_'+str(idx) prev_gfn_out.name = 'g_prev_' + str(idx)
prev_inner_gfn_outs.append( prev_gfn_out) prev_inner_gfn_outs.append(prev_gfn_out)
if idx < pos: if idx < pos:
zeros_like_diff_ins.append(tensor.zeros_like(diff_in)) zeros_like_diff_ins.append(tensor.zeros_like(diff_in))
else: else:
zeros_like_diff_ins.append(tensor.zeros_like(args[idx+offset])) zeros_like_diff_ins.append(
tensor.zeros_like(args[idx + offset]))
# 7.3. compute gradients of the inputs given one output # 7.3. compute gradients of the inputs given one output
for dx, out in enumerate(clean_outputs): for dx, out in enumerate(clean_outputs):
...@@ -1227,32 +1232,30 @@ class Scan(PureOp): ...@@ -1227,32 +1232,30 @@ class Scan(PureOp):
### ###
#### I need to clip the gradient HERE !! #### I need to clip the gradient HERE !!
if g_outs_no_shared[dx]: if g_outs_no_shared[dx]:
g_out_slices.append(g_outs_no_shared[dx][0]) g_out_slices.append(g_outs_no_shared[dx][0])
else: else:
g_out_slices.append(None) g_out_slices.append(None)
if getattr(out,'name',None) is not None: if getattr(out, 'name', None) is not None:
inner_g_out.name = 'g_'+out.name inner_g_out.name = 'g_' + out.name
else: else:
inner_g_out.name = 'g_'+str(dx) inner_g_out.name = 'g_' + str(dx)
inner_g_outs.append(inner_g_out) inner_g_outs.append(inner_g_out)
_g_out = inner_g_out _g_out = inner_g_out
grad_outs = compute_gradient(out, _g_out) grad_outs = compute_gradient(out, _g_out)
if not inner_gfn_outs: if not inner_gfn_outs:
for idx, gfn_out in enumerate(grad_outs): for idx, gfn_out in enumerate(grad_outs):
if idx >= self.n_seqs: if idx >= self.n_seqs:
inner_gfn_outs.append( prev_inner_gfn_outs[idx] ) inner_gfn_outs.append(prev_inner_gfn_outs[idx])
else: else:
inner_gfn_outs.append( None ) inner_gfn_outs.append(None)
# 7.4 Sum the gradients # 7.4 Sum the gradients
# safety check, some of this inputs might still not be # safety check, some of this inputs might still not be
# differentiable, for those we don't add them to the mix # differentiable, for those we don't add them to the mix
# (assume their gradient is 0) # (assume their gradient is 0)
for i,(x,y) in enumerate(zip(grad_outs, inner_gfn_outs)): for i, (x, y) in enumerate(zip(grad_outs, inner_gfn_outs)):
if x and y: if x and y:
inner_gfn_outs[i] = x+y inner_gfn_outs[i] = x + y
elif y: elif y:
inner_gfn_outs[i] = y inner_gfn_outs[i] = y
else: else:
...@@ -1276,28 +1279,27 @@ class Scan(PureOp): ...@@ -1276,28 +1279,27 @@ class Scan(PureOp):
g_outs[i] = theano.tensor.constant( g_outs[i] = theano.tensor.constant(
numpy.array(0, theano.config.floatX)) numpy.array(0, theano.config.floatX))
## 10. Get your sequence in order for the scan: ## 10. Get your sequence in order for the scan:
n_seqs = ( self.n_seqs + n_seqs = (self.n_seqs +
n_ins_mit_mot + n_ins_mit_mot +
n_ins_mit_sot + n_ins_mit_sot +
self.n_sit_sot + self.n_sit_sot +
self.n_nit_sot ) self.n_nit_sot)
offset = ( self.n_mit_mot_outs + offset = (self.n_mit_mot_outs +
self.n_mit_sot + self.n_mit_sot +
self.n_sit_sot ) self.n_sit_sot)
inner_seqs = ( seqs + inner_seqs = (seqs +
outs_mit_mot + outs_mit_mot +
outs_mit_sot + outs_mit_sot +
outs_sit_sot + outs_sit_sot +
inner_g_outs[offset:offset+self.n_nit_sot]) inner_g_outs[offset:offset + self.n_nit_sot])
scan_seqs = [ x[::-1] for x in args[1:self.n_seqs + 1]] scan_seqs = [x[::-1] for x in args[1:self.n_seqs + 1]]
offset = 0 offset = 0
for idx in xrange(self.n_mit_mot + self.n_mit_sot): for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx]) mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx]) maxtap = numpy.max(self.tap_array[idx])
seq = scan_outputs[offset+idx] seq = scan_outputs[offset + idx]
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
# We cut the sequence such that seq[i] to correspond to # We cut the sequence such that seq[i] to correspond to
# seq[i-k] # seq[i-k]
...@@ -1307,205 +1309,205 @@ class Scan(PureOp): ...@@ -1307,205 +1309,205 @@ class Scan(PureOp):
dim_offset = 0 dim_offset = 0
if maxtap == mintap and maxtap != 0: if maxtap == mintap and maxtap != 0:
nw_seq = seq[:abs(maxtap)] nw_seq = seq[:abs(maxtap)]
elif maxtap -k != 0 : elif maxtap - k != 0:
tmp = seq[dim_offset + k - mintap - 1:-(maxtap -k + 1)] nw_seq = seq[dim_offset + k - mintap - 1:\
nw_seq = tmp[::-1] -(maxtap - k + 1)][::-1]
else: else:
nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1] nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1]
if getattr(seq,'name', None) is not None: if getattr(seq, 'name', None) is not None:
nw_seq.name = seq.name + '[%d:]'%k nw_seq.name = seq.name + '[%d:]' % k
scan_seqs.append(nw_seq) scan_seqs.append(nw_seq)
offset += self.n_mit_sot offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot): for idx in xrange(self.n_sit_sot):
seq = scan_outputs[offset+idx][:-1] seq = scan_outputs[offset + idx][:-1]
scan_seqs.append(seq[::-1]) scan_seqs.append(seq[::-1])
offset = ( self.n_mit_mot_outs + offset = (self.n_mit_mot_outs +
self.n_mit_sot + self.n_mit_sot +
self.n_sit_sot ) self.n_sit_sot)
scan_seqs += [ x[::-1] for x in scan_seqs += [x[::-1] for x in
g_outs[offset:offset+self.n_nit_sot]] g_outs[offset:offset + self.n_nit_sot]]
scan_mit_mot = [] scan_mit_mot = []
inner_mit_mot = [] inner_mit_mot = []
scan_mit_mot_outs = [] scan_mit_mot_outs = []
mit_mot_taps = [] mit_mot_taps = []
mit_mot_out_slices = [] mit_mot_out_slices = []
out_pos = 0 out_pos = 0
ins_pos = n_seqs ins_pos = n_seqs
n_mit_mot_outs = 0 n_mit_mot_outs = 0
n_mit_mot_ins = 0 n_mit_mot_ins = 0
ins_pos = self.n_seqs ins_pos = self.n_seqs
for idx in xrange(self.n_mit_mot): for idx in xrange(self.n_mit_mot):
scan_mit_mot.append( g_outs[idx][::-1] ) scan_mit_mot.append(g_outs[idx][::-1])
mit_mot_taps.append([]) mit_mot_taps.append([])
mit_mot_out_slices.append([]) mit_mot_out_slices.append([])
for jdx in xrange(len(self.mit_mot_out_slices[idx])): for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_mit_mot.append( inner_g_outs[out_pos] ) inner_mit_mot.append(inner_g_outs[out_pos])
mit_mot_taps[idx].append( mit_mot_taps[idx].append(\
-self.mit_mot_out_slices[idx][jdx]) -self.mit_mot_out_slices[idx][jdx])
n_mit_mot_ins += 1 n_mit_mot_ins += 1
out_pos += 1 out_pos += 1
for jdx in xrange(len(self.tap_array[idx])): for jdx in xrange(len(self.tap_array[idx])):
inner_mit_mot.append( prev_inner_gfn_outs[ins_pos] ) inner_mit_mot.append(prev_inner_gfn_outs[ins_pos])
scan_mit_mot_outs.append( scan_mit_mot_outs.append(\
inner_gfn_outs[ ins_pos] ) inner_gfn_outs[ins_pos])
n_mit_mot_ins += 1 n_mit_mot_ins += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_outs += 1 n_mit_mot_outs += 1
mit_mot_taps[idx].append( -self.tap_array[idx][jdx]) mit_mot_taps[idx].append(-self.tap_array[idx][jdx])
mit_mot_out_slices[idx].append( mit_mot_out_slices[idx].append(\
-self.tap_array[idx][jdx] ) -self.tap_array[idx][jdx])
offset = self.n_mit_mot offset = self.n_mit_mot
for idx in xrange(self.n_mit_sot): for idx in xrange(self.n_mit_sot):
mit_mot_taps.append([]) mit_mot_taps.append([])
mit_mot_out_slices.append([]) mit_mot_out_slices.append([])
scan_mit_mot.append( g_outs[idx + offset][::-1] ) scan_mit_mot.append(g_outs[idx + offset][::-1])
idx_tap = idx + self.n_mit_mot idx_tap = idx + self.n_mit_mot
for jdx in xrange(len(self.tap_array[idx_tap])): for jdx in xrange(len(self.tap_array[idx_tap])):
inner_mit_mot.append( prev_inner_gfn_outs[ins_pos] ) inner_mit_mot.append(prev_inner_gfn_outs[ins_pos])
mit_mot_taps[idx+offset].append( mit_mot_taps[idx + offset].append(\
-self.tap_array[idx_tap][jdx] ) -self.tap_array[idx_tap][jdx])
mit_mot_out_slices[idx].append( mit_mot_out_slices[idx].append(\
-self.tap_array[idx_tap][jdx] ) -self.tap_array[idx_tap][jdx])
scan_mit_mot_outs.append(inner_gfn_outs[ ins_pos] ) scan_mit_mot_outs.append(inner_gfn_outs[ins_pos])
n_mit_mot_ins += 1 n_mit_mot_ins += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_outs += 1 n_mit_mot_outs += 1
inner_mit_mot.append( inner_g_outs[out_pos] ) inner_mit_mot.append(inner_g_outs[out_pos])
out_pos += 1 out_pos += 1
n_mit_mot_ins += 1 n_mit_mot_ins += 1
mit_mot_taps[idx+offset].append( 0 ) mit_mot_taps[idx + offset].append(0)
offset += self.n_mit_sot offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot): for idx in xrange(self.n_sit_sot):
mit_mot_taps.append([0,1]) mit_mot_taps.append([0, 1])
mit_mot_out_slices.append([1]) mit_mot_out_slices.append([1])
scan_mit_mot.append( g_outs[idx + offset][::-1] ) scan_mit_mot.append(g_outs[idx + offset][::-1])
scan_mit_mot_outs.append(inner_gfn_outs[ ins_pos ]) scan_mit_mot_outs.append(inner_gfn_outs[ins_pos])
inner_mit_mot += [ inner_g_outs[out_pos] inner_mit_mot += [inner_g_outs[out_pos],
, prev_inner_gfn_outs[ins_pos] ] prev_inner_gfn_outs[ins_pos]]
n_mit_mot_outs += 1 n_mit_mot_outs += 1
out_pos += 1 out_pos += 1
ins_pos += 1 ins_pos += 1
n_mit_mot_ins += 2 n_mit_mot_ins += 2
n_nit_sot = self.n_seqs n_nit_sot = self.n_seqs
scan_nit_sot_outs = inner_gfn_outs[:self.n_seqs] scan_nit_sot_outs = inner_gfn_outs[:self.n_seqs]
if self.truncate_gradient != -1 : if self.truncate_gradient != -1:
do_steps = tensor.minimum(args[0], self.truncate_gradient) do_steps = tensor.minimum(args[0], self.truncate_gradient)
else: else:
do_steps = args[0] do_steps = args[0]
offset = ( self.n_seqs offset = (self.n_seqs +
+ n_ins_mit_sot n_ins_mit_sot +
+ n_ins_mit_mot n_ins_mit_mot +
+ self.n_sit_sot ) self.n_sit_sot)
# Instead of shared outs use sit_sot # Instead of shared outs use sit_sot
n_sitsot_outs = len(prev_inner_gfn_outs[offset:]) n_sitsot_outs = len(prev_inner_gfn_outs[offset:])
scan_sitsot_ins = prev_inner_gfn_outs[offset:] scan_sitsot_ins = prev_inner_gfn_outs[offset:]
scan_sitsot_init = [] scan_sitsot_init = []
for x in zeros_like_diff_ins[offset:]: for x in zeros_like_diff_ins[offset:]:
shapes = [x.shape[i] for i in xrange(x.ndim)] shapes = [x.shape[i] for i in xrange(x.ndim)]
empty = tensor.zeros([do_steps +1]+shapes, empty = tensor.zeros([do_steps + 1] + shapes,
dtype=x.dtype) dtype=x.dtype)
scan_sitsot_init.append(empty) scan_sitsot_init.append(empty)
scan_sitsot_outs = inner_gfn_outs[offset:] scan_sitsot_outs = inner_gfn_outs[offset:]
tap_array = mit_mot_taps + [[-1] for k in tap_array = mit_mot_taps + [[-1] for k in
xrange(n_sitsot_outs)] xrange(n_sitsot_outs)]
info = {} info = {}
info['n_seqs'] = n_seqs info['n_seqs'] = n_seqs
info['n_mit_sot'] = 0 info['n_mit_sot'] = 0
info['tap_array'] = tap_array info['tap_array'] = tap_array
info['gpu'] = False info['gpu'] = False
n_mit_mot = ( self.n_mit_mot n_mit_mot = (self.n_mit_mot +
+ self.n_mit_sot self.n_mit_sot +
+ self.n_sit_sot ) self.n_sit_sot)
info['n_mit_mot'] = n_mit_mot info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices info['mit_mot_out_slices'] = mit_mot_out_slices
info['truncate_gradient'] = self.truncate_gradient info['truncate_gradient'] = self.truncate_gradient
info['n_sit_sot'] = n_sitsot_outs info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = self.n_shared_outs info['n_shared_outs'] = self.n_shared_outs
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while info['as_while'] = self.as_while
info['profile'] = self.profile info['profile'] = self.profile
if self.name: if self.name:
info['name'] = 'grad_of_' + self.name info['name'] = 'grad_of_' + self.name
else: else:
info['name'] = None info['name'] = None
info['mode'] = self.mode info['mode'] = self.mode
info['inplace'] = False info['inplace'] = False
n_mit_sot = 0 n_mit_sot = 0
n_sit_sot = 0 n_sit_sot = 0
offset = ( 1 offset = (1 +
+ self.n_seqs self.n_seqs +
+ self.n_mit_mot self.n_mit_mot +
+ self.n_mit_sot self.n_mit_sot +
+ self.n_sit_sot self.n_sit_sot +
+ self.n_nit_sot self.n_nit_sot +
+ self.n_shared_outs ) self.n_shared_outs)
scan_inputs = ( [do_steps] + scan_inputs = ([do_steps] +
scan_seqs + scan_seqs +
scan_mit_mot + scan_mit_mot +
scan_sitsot_init + scan_sitsot_init +
old_scan_init + old_scan_init +
[ args[0] for x in xrange(n_nit_sot) ] + [args[0] for x in xrange(n_nit_sot)] +
args[offset:] ) args[offset:])
offset = ( self.n_seqs offset = (self.n_seqs +
+ n_ins_mit_mot n_ins_mit_mot +
+ n_ins_mit_sot n_ins_mit_sot +
+ self.n_sit_sot self.n_sit_sot +
+ self.n_shared_outs ) self.n_shared_outs)
inner_other_args = self_inputs[offset:] inner_other_args = self_inputs[offset:]
inner_gfn_ins = ( inner_seqs + inner_gfn_ins = (inner_seqs +
inner_mit_mot + inner_mit_mot +
scan_sitsot_ins + scan_sitsot_ins +
old_scan_shared_ins + old_scan_shared_ins +
inner_other_args ) inner_other_args)
inner_gfn_outs = ( scan_mit_mot_outs + inner_gfn_outs = (scan_mit_mot_outs +
scan_sitsot_outs + scan_sitsot_outs +
scan_nit_sot_outs + scan_nit_sot_outs +
old_scan_shared_outs ) old_scan_shared_outs)
local_op = Scan( inner_gfn_ins, inner_gfn_outs, info ) local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
outputs = local_op(*scan_inputs) outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
outputs = [ outputs ] outputs = [outputs]
# Re-order the gradients correctly # Re-order the gradients correctly
gradients = [None] gradients = [None]
offset = ( self.n_mit_mot offset = (self.n_mit_mot +
+ self.n_mit_sot self.n_mit_sot +
+ self.n_sit_sot self.n_sit_sot +
+ n_sitsot_outs) n_sitsot_outs)
gradients += [ x[::-1] for x in outputs[offset:offset+self.n_seqs]] gradients += [x[::-1] for x in outputs[offset:offset + self.n_seqs]]
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
gradients += [ x[::-1] for x in outputs[:end]] gradients += [x[::-1] for x in outputs[:end]]
gradients += [ None for x in xrange(self.n_shared_outs)] gradients += [None for x in xrange(self.n_shared_outs)]
gradients += [ None for x in xrange(self.n_nit_sot) ] gradients += [None for x in xrange(self.n_nit_sot)]
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
gradients += [x[-1] for x in outputs[begin:end]] gradients += [x[-1] for x in outputs[begin:end]]
return gradients return gradients
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# Step 0. Don't work on the original tensor variables # Step 0. Don't work on the original tensor variables
rval = scan_utils.reconstruct_graph(self.inputs, rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs,'_rop') self.outputs, '_rop')
self_inputs = rval[0] self_inputs = rval[0]
self_outputs = rval[1] self_outputs = rval[1]
# Step 1. Compute the R_op of the inner function # Step 1. Compute the R_op of the inner function
inner_eval_points = [scan_utils.safe_new(x,'_evalpoint') for x in self_inputs] inner_eval_points = [scan_utils.safe_new(x, '_evalpoint')
for x in self_inputs]
if self.as_while: if self.as_while:
rop_self_outputs = self_outputs[:-1] rop_self_outputs = self_outputs[:-1]
else: else:
...@@ -1524,82 +1526,82 @@ class Scan(PureOp): ...@@ -1524,82 +1526,82 @@ class Scan(PureOp):
# evan point for the number of nit_sot which I think should just be # evan point for the number of nit_sot which I think should just be
# ignored (?) # ignored (?)
info = {} info = {}
info['n_seqs'] = self.n_seqs*2 info['n_seqs'] = self.n_seqs * 2
info['n_mit_sot'] = self.n_mit_sot*2 info['n_mit_sot'] = self.n_mit_sot * 2
info['n_sit_sot'] = self.n_sit_sot*2 info['n_sit_sot'] = self.n_sit_sot * 2
info['n_mit_mot'] = self.n_mit_mot*2 info['n_mit_mot'] = self.n_mit_mot * 2
info['n_nit_sot'] = self.n_nit_sot*2 info['n_nit_sot'] = self.n_nit_sot * 2
info['n_shared_outs'] = self.n_shared_outs*2 info['n_shared_outs'] = self.n_shared_outs * 2
info['gpu'] = False info['gpu'] = False
info['as_while'] = self.as_while info['as_while'] = self.as_while
info['profile'] = self.profile info['profile'] = self.profile
info['truncate_gradient'] = self.truncate_gradient info['truncate_gradient'] = self.truncate_gradient
if self.name: if self.name:
info['name'] = 'rop_of_'+self.name info['name'] = 'rop_of_' + self.name
else: else:
info['name'] = None info['name'] = None
info['mode'] = self.mode info['mode'] = self.mode
info['inplace'] = False info['inplace'] = False
info['mit_mot_out_slices'] = self.mit_mot_out_slices*2 info['mit_mot_out_slices'] = self.mit_mot_out_slices * 2
new_tap_array = [] new_tap_array = []
b = 0 b = 0
e = self.n_mit_mot e = self.n_mit_mot
new_tap_array += self.tap_array[b:e]*2 new_tap_array += self.tap_array[b:e] * 2
b = e b = e
e += self.n_mit_sot e += self.n_mit_sot
new_tap_array += self.tap_array[b:e]*2 new_tap_array += self.tap_array[b:e] * 2
b = e b = e
e += self.n_sit_sot e += self.n_sit_sot
new_tap_array += self.tap_array[b:e]*2 new_tap_array += self.tap_array[b:e] * 2
info['tap_array'] = new_tap_array info['tap_array'] = new_tap_array
# Sequences ... # Sequences ...
b = 1 b = 1
ib = 0 ib = 0
e = 1 + self.n_seqs e = 1 + self.n_seqs
ie = self.n_seqs ie = self.n_seqs
scan_seqs = inputs[b:e] + eval_points[b:e] scan_seqs = inputs[b:e] + eval_points[b:e]
inner_seqs = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_seqs = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_MOT sequences ... # MIT_MOT sequences ...
b = e b = e
e = e + self.n_mit_mot e = e + self.n_mit_mot
ib = ie ib = ie
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[:self.n_mit_mot]])) self.tap_array[:self.n_mit_mot]]))
scan_mit_mot = inputs[b:e] + eval_points[b:e] scan_mit_mot = inputs[b:e] + eval_points[b:e]
inner_mit_mot = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_mit_mot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_SOT sequences ... # MIT_SOT sequences ...
b = e b = e
e = e + self.n_mit_sot e = e + self.n_mit_sot
ib = ie ib = ie
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:self.n_mit_mot+self.n_mit_sot]])) self.tap_array[self.n_mit_mot:\
scan_mit_sot = inputs[b:e] + eval_points[b:e] self.n_mit_mot + self.n_mit_sot]]))
scan_mit_sot = inputs[b:e] + eval_points[b:e]
inner_mit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_mit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#SIT_SOT sequences ... #SIT_SOT sequences ...
b = e b = e
e = e + self.n_sit_sot e = e + self.n_sit_sot
ib = ie ib = ie
ie = ie + self.n_sit_sot ie = ie + self.n_sit_sot
scan_sit_sot = inputs[b:e] + eval_points[b:e] scan_sit_sot = inputs[b:e] + eval_points[b:e]
inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#Shared outs ... #Shared outs ...
b = e b = e
e = e + self.n_shared_outs e = e + self.n_shared_outs
ib = ie ib = ie
ie = ie + self.n_shared_outs ie = ie + self.n_shared_outs
scan_shared = inputs[b:e] + eval_points[b:e] scan_shared = inputs[b:e] + eval_points[b:e]
inner_shared = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_shared = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# NIT_SOT sequences # NIT_SOT sequences
b = e b = e
e = e + self.n_nit_sot e = e + self.n_nit_sot
scan_nit_sot = inputs[b:e]*2 scan_nit_sot = inputs[b:e] * 2
# All other arguments # All other arguments
scan_other = inputs[e:] + eval_points[e:] scan_other = inputs[e:] + eval_points[e:]
...@@ -1625,13 +1627,13 @@ class Scan(PureOp): ...@@ -1625,13 +1627,13 @@ class Scan(PureOp):
e = e + self.n_shared_outs e = e + self.n_shared_outs
inner_out_shared = self_outputs[b:e] + rop_outs[b:e] inner_out_shared = self_outputs[b:e] + rop_outs[b:e]
inner_ins = ( inner_seqs + inner_ins = (inner_seqs +
inner_mit_mot + inner_mit_mot +
inner_mit_sot + inner_mit_sot +
inner_sit_sot + inner_sit_sot +
inner_shared + inner_shared +
inner_other ) inner_other)
inner_outs = ( inner_out_mit_mot + inner_outs = (inner_out_mit_mot +
inner_out_mit_sot + inner_out_mit_sot +
inner_out_sit_sot + inner_out_sit_sot +
inner_out_nit_sot + inner_out_nit_sot +
...@@ -1639,35 +1641,35 @@ class Scan(PureOp): ...@@ -1639,35 +1641,35 @@ class Scan(PureOp):
if self.as_while: if self.as_while:
inner_outs += [self_outputs[-1]] inner_outs += [self_outputs[-1]]
scan_inputs = ( [inputs[0]] + scan_inputs = ([inputs[0]] +
scan_seqs + scan_seqs +
scan_mit_mot + scan_mit_mot +
scan_mit_sot + scan_mit_sot +
scan_sit_sot + scan_sit_sot +
scan_shared + scan_shared +
scan_nit_sot + scan_nit_sot +
scan_other) scan_other)
local_op = Scan( inner_ins, inner_outs, info ) local_op = Scan(inner_ins, inner_outs, info)
outputs = local_op(*scan_inputs) outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple): if type(outputs) not in (list, tuple):
outputs = [ outputs ] outputs = [outputs]
# Select only the result of the R_op results # Select only the result of the R_op results
final_outs = [] final_outs = []
b = self.n_mit_mot b = self.n_mit_mot
e = self.n_mit_mot*2 e = self.n_mit_mot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_mit_sot b = e + self.n_mit_sot
e = e + self.n_mit_sot*2 e = e + self.n_mit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_sit_sot b = e + self.n_sit_sot
e = e + self.n_sit_sot*2 e = e + self.n_sit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_nit_sot b = e + self.n_nit_sot
e = e + self.n_nit_sot*2 e = e + self.n_nit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_shared_outs b = e + self.n_shared_outs
e = e + self.n_shared_outs*2 e = e + self.n_shared_outs * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
return final_outs return final_outs
...@@ -1678,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call, ...@@ -1678,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, apply_cimpl, message, outputs_size, apply_time, apply_cimpl, message, outputs_size,
other_time): other_time):
# Scan overhead profile # Scan overhead profile
if any([isinstance(node.op, Scan) and v>0 for (_,node),v in if any([isinstance(node.op, Scan) and v > 0 for (_, node), v in
apply_time.items()]): apply_time.items()]):
print print
print 'Scan overhead:' print 'Scan overhead:'
print '<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(% scan op time)> <sub scan op time(% scan op time)> <node>' print ('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>')
total_super_scan_time = 0 total_super_scan_time = 0
total_scan_fct_time = 0 total_scan_fct_time = 0
total_scan_op_time = 0 total_scan_op_time = 0
for (_,node),v in apply_time.items(): for (_, node), v in apply_time.items():
if isinstance(node.op, Scan): if isinstance(node.op, Scan):
if v> 0: if v > 0:
scan_fct_time = node.op.mode_instance.fn_time scan_fct_time = node.op.mode_instance.fn_time
scan_op_time = node.op.mode_instance.local_time scan_op_time = node.op.mode_instance.local_time
total_super_scan_time += v total_super_scan_time += v
total_scan_fct_time += scan_fct_time total_scan_fct_time += scan_fct_time
total_scan_op_time += scan_op_time total_scan_op_time += scan_op_time
print ' %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%'%( print ' %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%' % (
v, scan_fct_time, scan_op_time, scan_fct_time/v*100, v,
scan_op_time/v*100), node scan_fct_time,
scan_op_time,
scan_fct_time / v * 100,
scan_op_time / v * 100), node
else: else:
print ' The node took 0s, so we can not compute the overhead', node print (' The node took 0s, so we can not '
print ' total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%'%( 'compute the overhead'), node
total_super_scan_time, total_scan_fct_time, total_scan_op_time, total_scan_fct_time/total_super_scan_time*100, total_scan_op_time/total_super_scan_time*100) print ' total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%' % (
total_super_scan_time,
total_scan_fct_time,
total_scan_op_time,
total_scan_fct_time / total_super_scan_time * 100,
total_scan_op_time / total_super_scan_time * 100)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论