提交 8b368cc8 authored 作者: Faruk Ahmed's avatar Faruk Ahmed

flake8 for scan_op

上级 d2aef4d9
......@@ -917,8 +917,7 @@ class Scan(PureOp):
dtype='int32')
from . import scan_perform_ext
p = lambda node, args, outs:\
scan_perform_ext.perform(
self.n_shared_outs,
scan_perform_ext.perform(self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
......@@ -1654,15 +1653,13 @@ class Scan(PureOp):
self_outs = self.outputs[:-1]
else:
self_outs = self.outputs
outs_shape = scan_utils.infer_shape(
outs=self_outs,
outs_shape = scan_utils.infer_shape(outs=self_outs,
inputs=self.inputs,
input_shapes=inner_ins_shapes)
# Will be used to check if outs_shape can be expressed without using
# variables in self.inputs.
# The shapes of node.inputs are valid.
validator = scan_utils.Validator(
valid=input_shapes,
validator = scan_utils.Validator(valid=input_shapes,
invalid=self.inputs,
valid_equivalent=out_equivalent)
......@@ -1887,18 +1884,18 @@ class Scan(PureOp):
# With the global mapping inferred, the individual mappings
# can be produced
mappings = {"outer_inp_from_outer_out" : {},
"inner_inp_from_outer_out" : {},
"inner_out_from_outer_out" : {},
"inner_inp_from_outer_inp" : {},
"inner_out_from_outer_inp" : {},
"outer_out_from_outer_inp" : {},
"outer_inp_from_inner_inp" : {},
"inner_out_from_inner_inp" : {},
"outer_out_from_inner_inp" : {},
"outer_inp_from_inner_out" : {},
"inner_inp_from_inner_out" : {},
"outer_out_from_inner_out" : {}}
mappings = {"outer_inp_from_outer_out": {},
"inner_inp_from_outer_out": {},
"inner_out_from_outer_out": {},
"inner_inp_from_outer_inp": {},
"inner_out_from_outer_inp": {},
"outer_out_from_outer_inp": {},
"outer_inp_from_inner_inp": {},
"inner_out_from_inner_inp": {},
"outer_out_from_inner_inp": {},
"outer_inp_from_inner_out": {},
"inner_inp_from_inner_out": {},
"outer_out_from_inner_out": {}}
for (oinp, iinp, iout, oout) in izip(outer_input_indices,
inner_input_indices,
......@@ -2031,8 +2028,7 @@ class Scan(PureOp):
# to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad(
cost=None,
grads = gradient.grad(cost=None,
known_grads=known_grads,
wrt=wrt,
consider_constant=wrt,
......@@ -2098,7 +2094,6 @@ class Scan(PureOp):
dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt)
known_grads = OrderedDict()
dc_dxts_idx = 0
for i in range(len(diff_outputs)):
......@@ -2180,7 +2175,7 @@ class Scan(PureOp):
seq = outs[idx]
for k in self.tap_array[idx]:
if outmaxtap - k != 0:
nw_seq = seq[k - mintap: -(outmaxtap-k)][::-1]
nw_seq = seq[k - mintap: -(outmaxtap - k)][::-1]
else:
nw_seq = seq[k - mintap:][::-1]
outer_inp_seqs.append(nw_seq)
......@@ -2288,7 +2283,6 @@ class Scan(PureOp):
new_inner_out_mitmot = theano.clone(new_inner_out_mitmot,
replace=[(to_replace, replacement)])
inner_out_mitmot.append(new_inner_out_mitmot)
if not disconnected_dC_dinps_t[ins_pos]:
......@@ -2553,8 +2547,7 @@ class Scan(PureOp):
gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])):
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
if t == 'connected':
gradients.append(x[::-1])
elif t == 'disconnected':
......@@ -2591,8 +2584,7 @@ class Scan(PureOp):
begin = end
end = begin + n_sitsot_outs
for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])):
for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
if t == 'connected':
gradients.append(x[-1])
elif t == 'disconnected':
......@@ -2640,8 +2632,7 @@ class Scan(PureOp):
rop_self_outputs = self_outputs
if self.info['n_shared_outs'] > 0:
rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']]
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs,
inner_eval_points)
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan
......@@ -2721,7 +2712,7 @@ class Scan(PureOp):
e = e + self.n_mit_sot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:\
self.tap_array[self.n_mit_mot: \
self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论