提交 8b368cc8 authored 作者: Faruk Ahmed's avatar Faruk Ahmed

flake8 for scan_op

上级 d2aef4d9
......@@ -125,7 +125,7 @@ class Scan(PureOp):
outputs,
info,
typeConstructor=None,
):
):
if 'gpua' not in info:
info['gpua'] = False
# adding properties into self
......@@ -346,8 +346,8 @@ class Scan(PureOp):
len(self.inner_shared(self.inputs)) +
len(self.inner_non_seqs(self.inputs)))
assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.")
("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.")
new_inputs = [inputs[0]]
# assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan (the '
......@@ -372,7 +372,7 @@ class Scan(PureOp):
'have the same dimensionality, you can increase the '
'dimensionality of the varialbe in the initial state of scan '
'by using dimshuffle or shape_padleft. '
)
)
err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The '
'initial state (`outputs_info` in scan nomenclature) '
......@@ -399,7 +399,7 @@ class Scan(PureOp):
'have the same dimensionality, you can increase the '
'dimensionality of the variable in the initial state of scan '
'by using dimshuffle or shape_padleft. '
)
)
def format(var, as_var):
"""
......@@ -440,9 +440,9 @@ class Scan(PureOp):
inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(),
self.mitmot_out_taps(),
self.outer_mitmot(inputs))):
zip(self.mitmot_taps(),
self.mitmot_out_taps(),
self.outer_mitmot(inputs))):
outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)):
......@@ -450,15 +450,15 @@ class Scan(PureOp):
outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
outer_mitmot.type.ndim,
str(inner_mitmot[ipos + k]),
inner_mitmot[ipos +
k].type.dtype,
inner_mitmot[ipos + k].type.ndim))
' in scan nomenclature) ',
str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
outer_mitmot.type.ndim,
str(inner_mitmot[ipos + k]),
inner_mitmot[ipos +
k].type.dtype,
inner_mitmot[ipos + k].type.ndim))
ipos += len(itaps)
for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos + k].type.dtype !=
......@@ -491,14 +491,14 @@ class Scan(PureOp):
outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitsot),
argoffset + idx,
outer_mitsot.type.dtype,
outer_mitsot.type.ndim,
str(inner_mitsots[ipos + k]),
inner_mitsots[ipos + k].type.dtype,
inner_mitsots[ipos + k].type.ndim))
' in scan nomenclature) ',
str(outer_mitsot),
argoffset + idx,
outer_mitsot.type.dtype,
outer_mitsot.type.ndim,
str(inner_mitsots[ipos + k]),
inner_mitsots[ipos + k].type.dtype,
inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps)
if inner_mitsot_out.type.dtype != outer_mitsot.type.dtype:
raise ValueError(err_msg2 %
......@@ -523,14 +523,14 @@ class Scan(PureOp):
new_inputs.append(outer_sitsot)
if (inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.dtype,
outer_sitsot.type.ndim,
str(inner_sitsot),
inner_sitsot.type.dtype,
inner_sitsot.type.ndim))
' in scan nomenclature) ',
str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.dtype,
outer_sitsot.type.ndim,
str(inner_sitsot),
inner_sitsot.type.dtype,
inner_sitsot.type.ndim))
if inner_sitsot_out.type.dtype != outer_sitsot.type.dtype:
raise ValueError(err_msg2 %
(str(outer_sitsot),
......@@ -570,14 +570,14 @@ class Scan(PureOp):
(outer_shared.dtype != inner_shared.dtype or
outer_shared.ndim != inner_shared.ndim)):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_shared),
argoffset + idx,
outer_shared.dtype,
outer_shared.ndim,
str(inner_shared),
inner_shared.dtype,
inner_shared.ndim))
' in scan nomenclature) ',
str(outer_shared),
argoffset + idx,
outer_shared.dtype,
outer_shared.ndim,
str(inner_shared),
inner_shared.dtype,
inner_shared.ndim))
# We do not need to call `format` on outer_nisot arguments.
# outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent
......@@ -595,7 +595,7 @@ class Scan(PureOp):
if inner_nonseq.type != outer_nonseq.type:
raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq)))
(str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that
......@@ -788,7 +788,7 @@ class Scan(PureOp):
# Wrap the corresponding input as usual. Leave the
# output as-is.
wrapped_inputs.append(In(self.inputs[input_idx],
borrow=False))
borrow=False))
input_idx += 1
# Wrap the inputs not associated to mitmots and wrap the remaining
......@@ -841,7 +841,7 @@ class Scan(PureOp):
profile = None
if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, integer_types))
and self.profile)):
and self.profile)):
if isinstance(self.profile, string_types):
profile = ScanProfileStats(name=self.profile)
else:
......@@ -866,7 +866,7 @@ class Scan(PureOp):
for out in self.fn.maker.fgraph.outputs]
try:
if impl == 'py':
if impl == 'py':
raise theano.gof.cmodule.MissingGXX
cython_mintaps = numpy.asarray(self.mintaps, dtype='int32')
cython_tap_array_len = \
......@@ -890,16 +890,16 @@ class Scan(PureOp):
d1 = numpy.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = numpy.zeros((d0, d1),
dtype='int32')
dtype='int32')
for _d0 in xrange(d0):
for _d1 in xrange(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0, _d1] = \
self.mit_mot_out_slices[_d0][_d1]
cython_vector_seqs = numpy.asarray(self.vector_seqs,
dtype='int32')
dtype='int32')
cython_vector_outs = numpy.asarray(self.vector_outs,
dtype='int32')
dtype='int32')
cython_mitmots_preallocated = numpy.asarray(self.mitmots_preallocated,
dtype='int32')
......@@ -910,39 +910,38 @@ class Scan(PureOp):
if hasattr(self, 'destroy_map'):
cython_destroy_map = [x in self.destroy_map
for x in xrange(len(node.outputs))]
for x in xrange(len(node.outputs))]
else:
cython_destroy_map = [0 for x in xrange(len(node.outputs))]
cython_destroy_map = numpy.asarray(cython_destroy_map,
dtype='int32')
from . import scan_perform_ext
p = lambda node, args, outs:\
scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
self.fn.fn,
self.fn,
cython_destroy_map,
args,
outs,
self, node)
scan_perform_ext.perform(self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
self.fn.fn,
self.fn,
cython_destroy_map,
args,
outs,
self, node)
except (ImportError, theano.gof.cmodule.MissingGXX):
p = self.execute
# default arguments are stored in the closure of `rval`
......@@ -1004,8 +1003,8 @@ class Scan(PureOp):
def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
return list_inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs + ntaps_upto_sit_sot]
......@@ -1094,7 +1093,7 @@ class Scan(PureOp):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot)
self.n_nit_sot)
return list_outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self, list_inputs):
......@@ -1153,10 +1152,10 @@ class Scan(PureOp):
for idx, seq in enumerate(args[1:self.seqs_arg_offset]):
if seq.shape[0] < n_steps:
raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps,
node.inputs[1 + idx],
seq.shape)
node.inputs[1 + idx],
seq.shape)
seqs.append(seq)
# 2. Allocate memory for the outputs. Construct the list:
......@@ -1165,15 +1164,15 @@ class Scan(PureOp):
# output
store_steps = [arg.shape[0] for arg
in args[self.seqs_arg_offset:
self.shared_arg_offset]]
in args[self.seqs_arg_offset:
self.shared_arg_offset]]
store_steps += [arg for arg in
args[self.nit_sot_arg_offset:
self.nit_sot_arg_offset + self.n_nit_sot]
]
args[self.nit_sot_arg_offset:
self.nit_sot_arg_offset + self.n_nit_sot]
]
pos = [(-self.mintaps[idx]) % store_steps[idx] for idx
in xrange(self.n_outs + self.n_nit_sot)]
in xrange(self.n_outs + self.n_nit_sot)]
if not getattr(self, 'destroy_map', None):
self.destroy_map = OrderedDict()
# 2.1 Create storage space for outputs
......@@ -1207,7 +1206,7 @@ class Scan(PureOp):
old_output_data = [None] * len(output_storage)
fn = self.fn.fn
offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
self.n_shared_outs)
self.n_shared_outs)
for idx in xrange(len(other_args)):
input_storage[idx + offset].storage[0] = other_args[idx]
......@@ -1221,7 +1220,7 @@ class Scan(PureOp):
for idx in xrange(self.n_seqs):
if self.vector_seqs[idx]:
input_storage[idx].storage[0] = \
seqs[idx][i:i + 1].reshape(())
seqs[idx][i:i + 1].reshape(())
else:
input_storage[idx].storage[0] = seqs[idx][i]
......@@ -1231,7 +1230,7 @@ class Scan(PureOp):
for tap in self.tap_array[idx]:
_idx = (pos[idx] + tap) % store_steps[idx]
input_storage[offset].storage[0] =\
outs[idx][0][_idx:_idx + 1].reshape(())
outs[idx][0][_idx:_idx + 1].reshape(())
offset += 1
else:
for tap in self.tap_array[idx]:
......@@ -1400,7 +1399,7 @@ class Scan(PureOp):
# This output tap has not been preallocated, recover
# its value as usual
outs[j][0][k + pos[j]] = \
output_storage[offset_out].storage[0]
output_storage[offset_out].storage[0]
offset_out += 1
mitmot_out_idx += 1
......@@ -1417,7 +1416,7 @@ class Scan(PureOp):
# Copy the output value to `outs`, if necessary
if store_steps[j] == 1 or self.vector_outs[j]:
outs[j][0][pos[j]] = \
output_storage[offset_out + j].storage[0]
output_storage[offset_out + j].storage[0]
else:
# Check whether the initialization of the output storage
# map for this output has been reused.
......@@ -1446,7 +1445,7 @@ class Scan(PureOp):
if i == 0:
jout = j + offset_out
shape = (store_steps[j],) + \
output_storage[jout].storage[0].shape
output_storage[jout].storage[0].shape
if len(output_storage[jout].storage[0].shape) == 0:
self.vector_outs[j] = True
dtype = output_storage[jout].storage[0].dtype
......@@ -1490,7 +1489,7 @@ class Scan(PureOp):
outs[j][0] = output_storage[jout].storage[0]
pos = [(idx + 1) % store for idx, store in
izip(pos, store_steps)]
izip(pos, store_steps)]
i = i + 1
# 6. Check if you need to re-order output buffers
......@@ -1654,17 +1653,15 @@ class Scan(PureOp):
self_outs = self.outputs[:-1]
else:
self_outs = self.outputs
outs_shape = scan_utils.infer_shape(
outs=self_outs,
inputs=self.inputs,
input_shapes=inner_ins_shapes)
outs_shape = scan_utils.infer_shape(outs=self_outs,
inputs=self.inputs,
input_shapes=inner_ins_shapes)
# Will be used to check if outs_shape can be expressed without using
# variables in self.inputs.
# The shapes of node.inputs are valid.
validator = scan_utils.Validator(
valid=input_shapes,
invalid=self.inputs,
valid_equivalent=out_equivalent)
validator = scan_utils.Validator(valid=input_shapes,
invalid=self.inputs,
valid_equivalent=out_equivalent)
offset = 1 + self.n_seqs
scan_outs = [x for x in input_shapes[offset:offset + n_outs]]
......@@ -1699,7 +1696,7 @@ class Scan(PureOp):
scan_outs.append(tuple(shp))
scan_outs += [x for x in
input_shapes[offset:offset + self.n_shared_outs]]
input_shapes[offset:offset + self.n_shared_outs]]
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if self.as_while:
......@@ -1763,7 +1760,7 @@ class Scan(PureOp):
j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True:
if connection_pattern[j_inp_idx][iidx] == True:
for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True
......@@ -1887,18 +1884,18 @@ class Scan(PureOp):
# With the global mapping inferred, the individual mappings
# can be produced
mappings = {"outer_inp_from_outer_out" : {},
"inner_inp_from_outer_out" : {},
"inner_out_from_outer_out" : {},
"inner_inp_from_outer_inp" : {},
"inner_out_from_outer_inp" : {},
"outer_out_from_outer_inp" : {},
"outer_inp_from_inner_inp" : {},
"inner_out_from_inner_inp" : {},
"outer_out_from_inner_inp" : {},
"outer_inp_from_inner_out" : {},
"inner_inp_from_inner_out" : {},
"outer_out_from_inner_out" : {}}
mappings = {"outer_inp_from_outer_out": {},
"inner_inp_from_outer_out": {},
"inner_out_from_outer_out": {},
"inner_inp_from_outer_inp": {},
"inner_out_from_outer_inp": {},
"outer_out_from_outer_inp": {},
"outer_inp_from_inner_inp": {},
"inner_out_from_inner_inp": {},
"outer_out_from_inner_inp": {},
"outer_inp_from_inner_out": {},
"inner_inp_from_inner_out": {},
"outer_out_from_inner_out": {}}
for (oinp, iinp, iout, oout) in izip(outer_input_indices,
inner_input_indices,
......@@ -1944,7 +1941,7 @@ class Scan(PureOp):
grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif self.n_mit_sot > 0:
grad_steps = self.outer_mitsot_outs(outs)[0].shape[0] +\
self.mintaps[self.n_mit_mot]
self.mintaps[self.n_mit_mot]
else:
grad_steps = inputs[0]
......@@ -2031,14 +2028,13 @@ class Scan(PureOp):
# to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad(
cost=None,
known_grads=known_grads,
wrt=wrt,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None',
null_gradients='return')
grads = gradient.grad(cost=None,
known_grads=known_grads,
wrt=wrt,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None',
null_gradients='return')
for i in range(len(wrt)):
gmp[wrt[i]] = grads[i]
......@@ -2098,7 +2094,6 @@ class Scan(PureOp):
dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt)
known_grads = OrderedDict()
dc_dxts_idx = 0
for i in range(len(diff_outputs)):
......@@ -2153,7 +2148,7 @@ class Scan(PureOp):
dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype)
x.astype(dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(safe_new(x))
......@@ -2180,7 +2175,7 @@ class Scan(PureOp):
seq = outs[idx]
for k in self.tap_array[idx]:
if outmaxtap - k != 0:
nw_seq = seq[k - mintap: -(outmaxtap-k)][::-1]
nw_seq = seq[k - mintap: -(outmaxtap - k)][::-1]
else:
nw_seq = seq[k - mintap:][::-1]
outer_inp_seqs.append(nw_seq)
......@@ -2288,7 +2283,6 @@ class Scan(PureOp):
new_inner_out_mitmot = theano.clone(new_inner_out_mitmot,
replace=[(to_replace, replacement)])
inner_out_mitmot.append(new_inner_out_mitmot)
if not disconnected_dC_dinps_t[ins_pos]:
......@@ -2553,8 +2547,7 @@ class Scan(PureOp):
gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])):
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
......@@ -2587,12 +2580,11 @@ class Scan(PureOp):
start = len(gradients)
gradients += [DisconnectedType()()
for x in xrange(self.n_nit_sot)]
for x in xrange(self.n_nit_sot)]
begin = end
end = begin + n_sitsot_outs
for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])):
for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
if t == 'connected':
gradients.append(x[-1])
elif t == 'disconnected':
......@@ -2629,7 +2621,7 @@ class Scan(PureOp):
self.outputs, '_rop')
self_inputs = rval[0]
rop_of_inputs = rval[0][:self.n_seqs + self.n_outs] + \
rval[0][self.n_seqs + self.n_outs + self.n_shared_outs:]
rval[0][self.n_seqs + self.n_outs + self.n_shared_outs:]
self_outputs = rval[1]
# Step 1. Compute the R_op of the inner function
inner_eval_points = [scan_utils.safe_new(x, '_evalpoint')
......@@ -2640,8 +2632,7 @@ class Scan(PureOp):
rop_self_outputs = self_outputs
if self.info['n_shared_outs'] > 0:
rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']]
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs,
inner_eval_points)
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan
......@@ -2721,8 +2712,8 @@ class Scan(PureOp):
e = e + self.n_mit_sot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:\
self.n_mit_mot + self.n_mit_sot]]))
self.tap_array[self.n_mit_mot: \
self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论