提交 39a2eb78 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

PEP 8 fixes

上级 5001661d
...@@ -334,14 +334,14 @@ class Scan(PureOp): ...@@ -334,14 +334,14 @@ class Scan(PureOp):
inner_mitsots[ipos + k].type.ndim)) inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1): inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_mitsot), (str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
outer_mitsot.type.ndim, outer_mitsot.type.ndim,
inner_mitsot_out.type.dtype, inner_mitsot_out.type.dtype,
inner_mitsot_out.type.ndim)) inner_mitsot_out.type.ndim))
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
...@@ -353,22 +353,22 @@ class Scan(PureOp): ...@@ -353,22 +353,22 @@ class Scan(PureOp):
inner_sitsot.ndim != outer_sitsot.ndim - 1): inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_sitsot), str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
outer_sitsot.type.ndim, outer_sitsot.type.ndim,
str(inner_sitsot), str(inner_sitsot),
inner_sitsot.type.dtype, inner_sitsot.type.dtype,
inner_sitsot.type.ndim)) inner_sitsot.type.ndim))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1): inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_sitsot), (str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
outer_sitsot.type.ndim, outer_sitsot.type.ndim,
inner_sitsot_out.type.dtype, inner_sitsot_out.type.dtype,
inner_sitsot_out.type.ndim)) inner_sitsot_out.type.ndim))
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
...@@ -1237,7 +1237,6 @@ class Scan(PureOp): ...@@ -1237,7 +1237,6 @@ class Scan(PureOp):
else: else:
return -1 return -1
def get_output_slice_idx(self, output_index): def get_output_slice_idx(self, output_index):
ipos = 0 ipos = 0
opos = output_index opos = output_index
...@@ -1287,6 +1286,7 @@ class Scan(PureOp): ...@@ -1287,6 +1286,7 @@ class Scan(PureOp):
dC_dXts = [] dC_dXts = []
Xts = [] Xts = []
for idx, Xt in enumerate(diff_outputs): for idx, Xt in enumerate(diff_outputs):
# We are looking for x[t-1] for a given x[t] # We are looking for x[t-1] for a given x[t]
if idx >= self.n_mit_mot_outs: if idx >= self.n_mit_mot_outs:
Xt_placeholder = Xt.type() Xt_placeholder = Xt.type()
...@@ -1333,7 +1333,7 @@ class Scan(PureOp): ...@@ -1333,7 +1333,7 @@ class Scan(PureOp):
# construct dX_dtm1 # construct dX_dtm1
dC_dXtm1s = [x.type() for x in dC_dinps_t[self.n_seqs:]] dC_dXtm1s = [x.type() for x in dC_dinps_t[self.n_seqs:]]
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx+self.n_seqs] += dC_dXtm1 dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
# Seqs # Seqs
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
...@@ -1399,7 +1399,7 @@ class Scan(PureOp): ...@@ -1399,7 +1399,7 @@ class Scan(PureOp):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos]) inner_out_mitmot.append(dC_dinps_t[ins_pos])
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
disconnected=False disconnected = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True undefined = True
...@@ -1429,11 +1429,11 @@ class Scan(PureOp): ...@@ -1429,11 +1429,11 @@ class Scan(PureOp):
mitmot_inp_taps[idx + offset].append(0) mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])): for jdx in xrange(len(self.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
mitmot_inp_taps[idx + offset].append( mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append( mitmot_out_taps[idx].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False disconnected = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
...@@ -1459,7 +1459,7 @@ class Scan(PureOp): ...@@ -1459,7 +1459,7 @@ class Scan(PureOp):
else: else:
outer_inp_mitmot.append( outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape, tensor.zeros(outs[idx + offset].shape,
dtype = dC_dinps_t[ins_pos].dtype)) dtype=dC_dinps_t[ins_pos].dtype))
inner_out_mitmot.append(dC_dinps_t[ins_pos]) inner_out_mitmot.append(dC_dinps_t[ins_pos])
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
...@@ -1511,7 +1511,7 @@ class Scan(PureOp): ...@@ -1511,7 +1511,7 @@ class Scan(PureOp):
outer_inp_sitsot = [ outer_inp_sitsot = [
tensor.zeros([grad_steps + 1] + tensor.zeros([grad_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)], [x.shape[i] for i in xrange(x.ndim)],
dtype = y.dtype) dtype=y.dtype)
for y, x in zip(inner_inp_sitsot, for y, x in zip(inner_inp_sitsot,
self.outer_non_seqs(inputs))] self.outer_non_seqs(inputs))]
......
...@@ -23,7 +23,8 @@ from theano import gof ...@@ -23,7 +23,8 @@ from theano import gof
from theano.gof.python25 import maxsize from theano.gof.python25 import maxsize
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import deep_copy_op, optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op
import scan_op import scan_op
import scan_utils import scan_utils
...@@ -221,7 +222,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -221,7 +222,7 @@ class PushOutNonSeqScan(gof.Optimizer):
'to move some computation fron scan ' 'to move some computation fron scan '
'which is not allowed to move. Report ' 'which is not allowed to move. Report '
'this on theano-users list'), x) 'this on theano-users list'), x)
outside_ins = [x.type.filter_variable(y) for x,y in outside_ins = [x.type.filter_variable(y) for x, y in
zip(nd.inputs, outside_ins)] zip(nd.inputs, outside_ins)]
nw_outer_node = nd.op.make_node(*outside_ins) nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
...@@ -681,7 +682,10 @@ class ScanSaveMem(gof.Optimizer): ...@@ -681,7 +682,10 @@ class ScanSaveMem(gof.Optimizer):
if (nw_inputs[offset + idx].owner and if (nw_inputs[offset + idx].owner and
isinstance(nw_inputs[offset + idx].owner.op, isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor) and tensor.IncSubtensor) and
isinstance(nw_inputs[offset+idx].owner.op.idx_list[0], slice)): isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0],
slice)):
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = tensor.as_tensor_variable(val) cval = tensor.as_tensor_variable(val)
initl = tensor.as_tensor_variable(init_l[i]) initl = tensor.as_tensor_variable(init_l[i])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论