提交 8b368cc8 authored 作者: Faruk Ahmed's avatar Faruk Ahmed

flake8 for scan_op

上级 d2aef4d9
...@@ -917,8 +917,7 @@ class Scan(PureOp): ...@@ -917,8 +917,7 @@ class Scan(PureOp):
dtype='int32') dtype='int32')
from . import scan_perform_ext from . import scan_perform_ext
p = lambda node, args, outs:\ p = lambda node, args, outs:\
scan_perform_ext.perform( scan_perform_ext.perform(self.n_shared_outs,
self.n_shared_outs,
self.n_mit_mot_outs, self.n_mit_mot_outs,
self.n_seqs, self.n_seqs,
self.n_mit_mot, self.n_mit_mot,
...@@ -1654,15 +1653,13 @@ class Scan(PureOp): ...@@ -1654,15 +1653,13 @@ class Scan(PureOp):
self_outs = self.outputs[:-1] self_outs = self.outputs[:-1]
else: else:
self_outs = self.outputs self_outs = self.outputs
outs_shape = scan_utils.infer_shape( outs_shape = scan_utils.infer_shape(outs=self_outs,
outs=self_outs,
inputs=self.inputs, inputs=self.inputs,
input_shapes=inner_ins_shapes) input_shapes=inner_ins_shapes)
# Will be used to check if outs_shape can be expressed without using # Will be used to check if outs_shape can be expressed without using
# variables in self.inputs. # variables in self.inputs.
# The shapes of node.inputs are valid. # The shapes of node.inputs are valid.
validator = scan_utils.Validator( validator = scan_utils.Validator(valid=input_shapes,
valid=input_shapes,
invalid=self.inputs, invalid=self.inputs,
valid_equivalent=out_equivalent) valid_equivalent=out_equivalent)
...@@ -1887,18 +1884,18 @@ class Scan(PureOp): ...@@ -1887,18 +1884,18 @@ class Scan(PureOp):
# With the global mapping inferred, the individual mappings # With the global mapping inferred, the individual mappings
# can be produced # can be produced
mappings = {"outer_inp_from_outer_out" : {}, mappings = {"outer_inp_from_outer_out": {},
"inner_inp_from_outer_out" : {}, "inner_inp_from_outer_out": {},
"inner_out_from_outer_out" : {}, "inner_out_from_outer_out": {},
"inner_inp_from_outer_inp" : {}, "inner_inp_from_outer_inp": {},
"inner_out_from_outer_inp" : {}, "inner_out_from_outer_inp": {},
"outer_out_from_outer_inp" : {}, "outer_out_from_outer_inp": {},
"outer_inp_from_inner_inp" : {}, "outer_inp_from_inner_inp": {},
"inner_out_from_inner_inp" : {}, "inner_out_from_inner_inp": {},
"outer_out_from_inner_inp" : {}, "outer_out_from_inner_inp": {},
"outer_inp_from_inner_out" : {}, "outer_inp_from_inner_out": {},
"inner_inp_from_inner_out" : {}, "inner_inp_from_inner_out": {},
"outer_out_from_inner_out" : {}} "outer_out_from_inner_out": {}}
for (oinp, iinp, iout, oout) in izip(outer_input_indices, for (oinp, iinp, iout, oout) in izip(outer_input_indices,
inner_input_indices, inner_input_indices,
...@@ -2031,8 +2028,7 @@ class Scan(PureOp): ...@@ -2031,8 +2028,7 @@ class Scan(PureOp):
# to X. # to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()]) known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad( grads = gradient.grad(cost=None,
cost=None,
known_grads=known_grads, known_grads=known_grads,
wrt=wrt, wrt=wrt,
consider_constant=wrt, consider_constant=wrt,
...@@ -2098,7 +2094,6 @@ class Scan(PureOp): ...@@ -2098,7 +2094,6 @@ class Scan(PureOp):
dC_dXt = safe_new(dC_douts[idx][0]) dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt) dC_dXts.append(dC_dXt)
known_grads = OrderedDict() known_grads = OrderedDict()
dc_dxts_idx = 0 dc_dxts_idx = 0
for i in range(len(diff_outputs)): for i in range(len(diff_outputs)):
...@@ -2180,7 +2175,7 @@ class Scan(PureOp): ...@@ -2180,7 +2175,7 @@ class Scan(PureOp):
seq = outs[idx] seq = outs[idx]
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
if outmaxtap - k != 0: if outmaxtap - k != 0:
nw_seq = seq[k - mintap: -(outmaxtap-k)][::-1] nw_seq = seq[k - mintap: -(outmaxtap - k)][::-1]
else: else:
nw_seq = seq[k - mintap:][::-1] nw_seq = seq[k - mintap:][::-1]
outer_inp_seqs.append(nw_seq) outer_inp_seqs.append(nw_seq)
...@@ -2288,7 +2283,6 @@ class Scan(PureOp): ...@@ -2288,7 +2283,6 @@ class Scan(PureOp):
new_inner_out_mitmot = theano.clone(new_inner_out_mitmot, new_inner_out_mitmot = theano.clone(new_inner_out_mitmot,
replace=[(to_replace, replacement)]) replace=[(to_replace, replacement)])
inner_out_mitmot.append(new_inner_out_mitmot) inner_out_mitmot.append(new_inner_out_mitmot)
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
...@@ -2553,8 +2547,7 @@ class Scan(PureOp): ...@@ -2553,8 +2547,7 @@ class Scan(PureOp):
gradients.append(NullType(t)()) gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate( for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
zip(outputs[:end], type_outs[:end])):
if t == 'connected': if t == 'connected':
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == 'disconnected': elif t == 'disconnected':
...@@ -2591,8 +2584,7 @@ class Scan(PureOp): ...@@ -2591,8 +2584,7 @@ class Scan(PureOp):
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
for p, (x, t) in enumerate( for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
zip(outputs[begin:end], type_outs[begin:end])):
if t == 'connected': if t == 'connected':
gradients.append(x[-1]) gradients.append(x[-1])
elif t == 'disconnected': elif t == 'disconnected':
...@@ -2640,8 +2632,7 @@ class Scan(PureOp): ...@@ -2640,8 +2632,7 @@ class Scan(PureOp):
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
if self.info['n_shared_outs'] > 0: if self.info['n_shared_outs'] > 0:
rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']] rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']]
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
inner_eval_points)
if type(rop_outs) not in (list, tuple): if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs] rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan # Step 2. Figure out what corresponds to what in the scan
...@@ -2721,7 +2712,7 @@ class Scan(PureOp): ...@@ -2721,7 +2712,7 @@ class Scan(PureOp):
e = e + self.n_mit_sot e = e + self.n_mit_sot
ib = ie ib = ie
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:\ self.tap_array[self.n_mit_mot: \
self.n_mit_mot + self.n_mit_sot]])) self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = [] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]): for inp, evp in zip(inputs[b:e], eval_points[b:e]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论