提交 8b368cc8 authored 作者: Faruk Ahmed's avatar Faruk Ahmed

flake8 for scan_op

上级 d2aef4d9
...@@ -125,7 +125,7 @@ class Scan(PureOp): ...@@ -125,7 +125,7 @@ class Scan(PureOp):
outputs, outputs,
info, info,
typeConstructor=None, typeConstructor=None,
): ):
if 'gpua' not in info: if 'gpua' not in info:
info['gpua'] = False info['gpua'] = False
# adding properties into self # adding properties into self
...@@ -346,8 +346,8 @@ class Scan(PureOp): ...@@ -346,8 +346,8 @@ class Scan(PureOp):
len(self.inner_shared(self.inputs)) + len(self.inner_shared(self.inputs)) +
len(self.inner_non_seqs(self.inputs))) len(self.inner_non_seqs(self.inputs)))
assert n_outer_ins == n_inner_ins, \ assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan" ("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.") " does not match the number of inputs given to scan.")
new_inputs = [inputs[0]] new_inputs = [inputs[0]]
# assert dtype is consistent # assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan (the ' err_msg1 = ('When compiling the inner function of scan (the '
...@@ -372,7 +372,7 @@ class Scan(PureOp): ...@@ -372,7 +372,7 @@ class Scan(PureOp):
'have the same dimensionality, you can increase the ' 'have the same dimensionality, you can increase the '
'dimensionality of the varialbe in the initial state of scan ' 'dimensionality of the varialbe in the initial state of scan '
'by using dimshuffle or shape_padleft. ' 'by using dimshuffle or shape_padleft. '
) )
err_msg2 = ('When compiling the inner function of scan the ' err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
'initial state (`outputs_info` in scan nomenclature) ' 'initial state (`outputs_info` in scan nomenclature) '
...@@ -399,7 +399,7 @@ class Scan(PureOp): ...@@ -399,7 +399,7 @@ class Scan(PureOp):
'have the same dimensionality, you can increase the ' 'have the same dimensionality, you can increase the '
'dimensionality of the variable in the initial state of scan ' 'dimensionality of the variable in the initial state of scan '
'by using dimshuffle or shape_padleft. ' 'by using dimshuffle or shape_padleft. '
) )
def format(var, as_var): def format(var, as_var):
""" """
...@@ -440,9 +440,9 @@ class Scan(PureOp): ...@@ -440,9 +440,9 @@ class Scan(PureOp):
inner_mitmot = self.inner_mitmot(self.inputs) inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs) inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
for idx, (itaps, otaps, _outer_mitmot) in enumerate( for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), zip(self.mitmot_taps(),
self.mitmot_out_taps(), self.mitmot_out_taps(),
self.outer_mitmot(inputs))): self.outer_mitmot(inputs))):
outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos]) outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot) new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
...@@ -450,15 +450,15 @@ class Scan(PureOp): ...@@ -450,15 +450,15 @@ class Scan(PureOp):
outer_mitmot.type.dtype or outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1): inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitmot), str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.type.dtype, outer_mitmot.type.dtype,
outer_mitmot.type.ndim, outer_mitmot.type.ndim,
str(inner_mitmot[ipos + k]), str(inner_mitmot[ipos + k]),
inner_mitmot[ipos + inner_mitmot[ipos +
k].type.dtype, k].type.dtype,
inner_mitmot[ipos + k].type.ndim)) inner_mitmot[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
for k in xrange(len(otaps)): for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos + k].type.dtype != if (inner_mitmot_outs[opos + k].type.dtype !=
...@@ -491,14 +491,14 @@ class Scan(PureOp): ...@@ -491,14 +491,14 @@ class Scan(PureOp):
outer_mitsot.type.dtype or outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1): inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitsot), str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
outer_mitsot.type.ndim, outer_mitsot.type.ndim,
str(inner_mitsots[ipos + k]), str(inner_mitsots[ipos + k]),
inner_mitsots[ipos + k].type.dtype, inner_mitsots[ipos + k].type.dtype,
inner_mitsots[ipos + k].type.ndim)) inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
if inner_mitsot_out.type.dtype != outer_mitsot.type.dtype: if inner_mitsot_out.type.dtype != outer_mitsot.type.dtype:
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
...@@ -523,14 +523,14 @@ class Scan(PureOp): ...@@ -523,14 +523,14 @@ class Scan(PureOp):
new_inputs.append(outer_sitsot) new_inputs.append(outer_sitsot)
if (inner_sitsot.ndim != outer_sitsot.ndim - 1): if (inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_sitsot), str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
outer_sitsot.type.ndim, outer_sitsot.type.ndim,
str(inner_sitsot), str(inner_sitsot),
inner_sitsot.type.dtype, inner_sitsot.type.dtype,
inner_sitsot.type.ndim)) inner_sitsot.type.ndim))
if inner_sitsot_out.type.dtype != outer_sitsot.type.dtype: if inner_sitsot_out.type.dtype != outer_sitsot.type.dtype:
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_sitsot), (str(outer_sitsot),
...@@ -570,14 +570,14 @@ class Scan(PureOp): ...@@ -570,14 +570,14 @@ class Scan(PureOp):
(outer_shared.dtype != inner_shared.dtype or (outer_shared.dtype != inner_shared.dtype or
outer_shared.ndim != inner_shared.ndim)): outer_shared.ndim != inner_shared.ndim)):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_shared), str(outer_shared),
argoffset + idx, argoffset + idx,
outer_shared.dtype, outer_shared.dtype,
outer_shared.ndim, outer_shared.ndim,
str(inner_shared), str(inner_shared),
inner_shared.dtype, inner_shared.dtype,
inner_shared.ndim)) inner_shared.ndim))
# We do not need to call `format` on outer_nisot arguments. # We do not need to call `format` on outer_nisot arguments.
# outer_nitsot stands for no input tap single output tap. This means # outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent # these are states that do not feed anything back in the recurrent
...@@ -595,7 +595,7 @@ class Scan(PureOp): ...@@ -595,7 +595,7 @@ class Scan(PureOp):
if inner_nonseq.type != outer_nonseq.type: if inner_nonseq.type != outer_nonseq.type:
raise ValueError(('Argument %s given to scan node does not' raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s') % ' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq))) (str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs): for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
...@@ -788,7 +788,7 @@ class Scan(PureOp): ...@@ -788,7 +788,7 @@ class Scan(PureOp):
# Wrap the corresponding input as usual. Leave the # Wrap the corresponding input as usual. Leave the
# output as-is. # output as-is.
wrapped_inputs.append(In(self.inputs[input_idx], wrapped_inputs.append(In(self.inputs[input_idx],
borrow=False)) borrow=False))
input_idx += 1 input_idx += 1
# Wrap the inputs not associated to mitmots and wrap the remaining # Wrap the inputs not associated to mitmots and wrap the remaining
...@@ -841,7 +841,7 @@ class Scan(PureOp): ...@@ -841,7 +841,7 @@ class Scan(PureOp):
profile = None profile = None
if (theano.config.profile or if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, integer_types)) (isinstance(self.profile, (string_types, bool, integer_types))
and self.profile)): and self.profile)):
if isinstance(self.profile, string_types): if isinstance(self.profile, string_types):
profile = ScanProfileStats(name=self.profile) profile = ScanProfileStats(name=self.profile)
else: else:
...@@ -866,7 +866,7 @@ class Scan(PureOp): ...@@ -866,7 +866,7 @@ class Scan(PureOp):
for out in self.fn.maker.fgraph.outputs] for out in self.fn.maker.fgraph.outputs]
try: try:
if impl == 'py': if impl == 'py':
raise theano.gof.cmodule.MissingGXX raise theano.gof.cmodule.MissingGXX
cython_mintaps = numpy.asarray(self.mintaps, dtype='int32') cython_mintaps = numpy.asarray(self.mintaps, dtype='int32')
cython_tap_array_len = \ cython_tap_array_len = \
...@@ -890,16 +890,16 @@ class Scan(PureOp): ...@@ -890,16 +890,16 @@ class Scan(PureOp):
d1 = numpy.max(cython_mit_mot_out_nslices) d1 = numpy.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices) d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = numpy.zeros((d0, d1), cython_mit_mot_out_slices = numpy.zeros((d0, d1),
dtype='int32') dtype='int32')
for _d0 in xrange(d0): for _d0 in xrange(d0):
for _d1 in xrange(cython_mit_mot_out_nslices[_d0]): for _d1 in xrange(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0, _d1] = \ cython_mit_mot_out_slices[_d0, _d1] = \
self.mit_mot_out_slices[_d0][_d1] self.mit_mot_out_slices[_d0][_d1]
cython_vector_seqs = numpy.asarray(self.vector_seqs, cython_vector_seqs = numpy.asarray(self.vector_seqs,
dtype='int32') dtype='int32')
cython_vector_outs = numpy.asarray(self.vector_outs, cython_vector_outs = numpy.asarray(self.vector_outs,
dtype='int32') dtype='int32')
cython_mitmots_preallocated = numpy.asarray(self.mitmots_preallocated, cython_mitmots_preallocated = numpy.asarray(self.mitmots_preallocated,
dtype='int32') dtype='int32')
...@@ -910,39 +910,38 @@ class Scan(PureOp): ...@@ -910,39 +910,38 @@ class Scan(PureOp):
if hasattr(self, 'destroy_map'): if hasattr(self, 'destroy_map'):
cython_destroy_map = [x in self.destroy_map cython_destroy_map = [x in self.destroy_map
for x in xrange(len(node.outputs))] for x in xrange(len(node.outputs))]
else: else:
cython_destroy_map = [0 for x in xrange(len(node.outputs))] cython_destroy_map = [0 for x in xrange(len(node.outputs))]
cython_destroy_map = numpy.asarray(cython_destroy_map, cython_destroy_map = numpy.asarray(cython_destroy_map,
dtype='int32') dtype='int32')
from . import scan_perform_ext from . import scan_perform_ext
p = lambda node, args, outs:\ p = lambda node, args, outs:\
scan_perform_ext.perform( scan_perform_ext.perform(self.n_shared_outs,
self.n_shared_outs, self.n_mit_mot_outs,
self.n_mit_mot_outs, self.n_seqs,
self.n_seqs, self.n_mit_mot,
self.n_mit_mot, self.n_mit_sot,
self.n_mit_sot, self.n_sit_sot,
self.n_sit_sot, self.n_nit_sot,
self.n_nit_sot, args[0],
args[0], self.as_while,
self.as_while, cython_mintaps,
cython_mintaps, cython_tap_array,
cython_tap_array, cython_tap_array_len,
cython_tap_array_len, cython_vector_seqs,
cython_vector_seqs, cython_vector_outs,
cython_vector_outs, cython_mit_mot_out_slices,
cython_mit_mot_out_slices, cython_mit_mot_out_nslices,
cython_mit_mot_out_nslices, cython_mitmots_preallocated,
cython_mitmots_preallocated, cython_inps_is_tensor,
cython_inps_is_tensor, cython_outs_is_tensor,
cython_outs_is_tensor, self.fn.fn,
self.fn.fn, self.fn,
self.fn, cython_destroy_map,
cython_destroy_map, args,
args, outs,
outs, self, node)
self, node)
except (ImportError, theano.gof.cmodule.MissingGXX): except (ImportError, theano.gof.cmodule.MissingGXX):
p = self.execute p = self.execute
# default arguments are stored in the closure of `rval` # default arguments are stored in the closure of `rval`
...@@ -1004,8 +1003,8 @@ class Scan(PureOp): ...@@ -1004,8 +1003,8 @@ class Scan(PureOp):
def inner_mitsot(self, list_inputs): def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot]) n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
ntaps_upto_sit_sot = sum(len(x) for x in ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
return list_inputs[self.n_seqs + n_mitmot_taps: return list_inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs + ntaps_upto_sit_sot] self.n_seqs + ntaps_upto_sit_sot]
...@@ -1094,7 +1093,7 @@ class Scan(PureOp): ...@@ -1094,7 +1093,7 @@ class Scan(PureOp):
if isinstance(list_outputs, Apply): if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs list_outputs = list_outputs.outputs
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
return list_outputs[offset:offset + self.n_shared_outs] return list_outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self, list_inputs): def inner_non_seqs(self, list_inputs):
...@@ -1153,10 +1152,10 @@ class Scan(PureOp): ...@@ -1153,10 +1152,10 @@ class Scan(PureOp):
for idx, seq in enumerate(args[1:self.seqs_arg_offset]): for idx, seq in enumerate(args[1:self.seqs_arg_offset]):
if seq.shape[0] < n_steps: if seq.shape[0] < n_steps:
raise ValueError(('Sequence is shorter then the required ' raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, ' 'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps, 'seq.shape):'), n_steps,
node.inputs[1 + idx], node.inputs[1 + idx],
seq.shape) seq.shape)
seqs.append(seq) seqs.append(seq)
# 2. Allocate memory for the outputs. Construct the list: # 2. Allocate memory for the outputs. Construct the list:
...@@ -1165,15 +1164,15 @@ class Scan(PureOp): ...@@ -1165,15 +1164,15 @@ class Scan(PureOp):
# output # output
store_steps = [arg.shape[0] for arg store_steps = [arg.shape[0] for arg
in args[self.seqs_arg_offset: in args[self.seqs_arg_offset:
self.shared_arg_offset]] self.shared_arg_offset]]
store_steps += [arg for arg in store_steps += [arg for arg in
args[self.nit_sot_arg_offset: args[self.nit_sot_arg_offset:
self.nit_sot_arg_offset + self.n_nit_sot] self.nit_sot_arg_offset + self.n_nit_sot]
] ]
pos = [(-self.mintaps[idx]) % store_steps[idx] for idx pos = [(-self.mintaps[idx]) % store_steps[idx] for idx
in xrange(self.n_outs + self.n_nit_sot)] in xrange(self.n_outs + self.n_nit_sot)]
if not getattr(self, 'destroy_map', None): if not getattr(self, 'destroy_map', None):
self.destroy_map = OrderedDict() self.destroy_map = OrderedDict()
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
...@@ -1207,7 +1206,7 @@ class Scan(PureOp): ...@@ -1207,7 +1206,7 @@ class Scan(PureOp):
old_output_data = [None] * len(output_storage) old_output_data = [None] * len(output_storage)
fn = self.fn.fn fn = self.fn.fn
offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) + offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
self.n_shared_outs) self.n_shared_outs)
for idx in xrange(len(other_args)): for idx in xrange(len(other_args)):
input_storage[idx + offset].storage[0] = other_args[idx] input_storage[idx + offset].storage[0] = other_args[idx]
...@@ -1221,7 +1220,7 @@ class Scan(PureOp): ...@@ -1221,7 +1220,7 @@ class Scan(PureOp):
for idx in xrange(self.n_seqs): for idx in xrange(self.n_seqs):
if self.vector_seqs[idx]: if self.vector_seqs[idx]:
input_storage[idx].storage[0] = \ input_storage[idx].storage[0] = \
seqs[idx][i:i + 1].reshape(()) seqs[idx][i:i + 1].reshape(())
else: else:
input_storage[idx].storage[0] = seqs[idx][i] input_storage[idx].storage[0] = seqs[idx][i]
...@@ -1231,7 +1230,7 @@ class Scan(PureOp): ...@@ -1231,7 +1230,7 @@ class Scan(PureOp):
for tap in self.tap_array[idx]: for tap in self.tap_array[idx]:
_idx = (pos[idx] + tap) % store_steps[idx] _idx = (pos[idx] + tap) % store_steps[idx]
input_storage[offset].storage[0] =\ input_storage[offset].storage[0] =\
outs[idx][0][_idx:_idx + 1].reshape(()) outs[idx][0][_idx:_idx + 1].reshape(())
offset += 1 offset += 1
else: else:
for tap in self.tap_array[idx]: for tap in self.tap_array[idx]:
...@@ -1400,7 +1399,7 @@ class Scan(PureOp): ...@@ -1400,7 +1399,7 @@ class Scan(PureOp):
# This output tap has not been preallocated, recover # This output tap has not been preallocated, recover
# its value as usual # its value as usual
outs[j][0][k + pos[j]] = \ outs[j][0][k + pos[j]] = \
output_storage[offset_out].storage[0] output_storage[offset_out].storage[0]
offset_out += 1 offset_out += 1
mitmot_out_idx += 1 mitmot_out_idx += 1
...@@ -1417,7 +1416,7 @@ class Scan(PureOp): ...@@ -1417,7 +1416,7 @@ class Scan(PureOp):
# Copy the output value to `outs`, if necessary # Copy the output value to `outs`, if necessary
if store_steps[j] == 1 or self.vector_outs[j]: if store_steps[j] == 1 or self.vector_outs[j]:
outs[j][0][pos[j]] = \ outs[j][0][pos[j]] = \
output_storage[offset_out + j].storage[0] output_storage[offset_out + j].storage[0]
else: else:
# Check whether the initialization of the output storage # Check whether the initialization of the output storage
# map for this output has been reused. # map for this output has been reused.
...@@ -1446,7 +1445,7 @@ class Scan(PureOp): ...@@ -1446,7 +1445,7 @@ class Scan(PureOp):
if i == 0: if i == 0:
jout = j + offset_out jout = j + offset_out
shape = (store_steps[j],) + \ shape = (store_steps[j],) + \
output_storage[jout].storage[0].shape output_storage[jout].storage[0].shape
if len(output_storage[jout].storage[0].shape) == 0: if len(output_storage[jout].storage[0].shape) == 0:
self.vector_outs[j] = True self.vector_outs[j] = True
dtype = output_storage[jout].storage[0].dtype dtype = output_storage[jout].storage[0].dtype
...@@ -1490,7 +1489,7 @@ class Scan(PureOp): ...@@ -1490,7 +1489,7 @@ class Scan(PureOp):
outs[j][0] = output_storage[jout].storage[0] outs[j][0] = output_storage[jout].storage[0]
pos = [(idx + 1) % store for idx, store in pos = [(idx + 1) % store for idx, store in
izip(pos, store_steps)] izip(pos, store_steps)]
i = i + 1 i = i + 1
# 6. Check if you need to re-order output buffers # 6. Check if you need to re-order output buffers
...@@ -1654,17 +1653,15 @@ class Scan(PureOp): ...@@ -1654,17 +1653,15 @@ class Scan(PureOp):
self_outs = self.outputs[:-1] self_outs = self.outputs[:-1]
else: else:
self_outs = self.outputs self_outs = self.outputs
outs_shape = scan_utils.infer_shape( outs_shape = scan_utils.infer_shape(outs=self_outs,
outs=self_outs, inputs=self.inputs,
inputs=self.inputs, input_shapes=inner_ins_shapes)
input_shapes=inner_ins_shapes)
# Will be used to check if outs_shape can be expressed without using # Will be used to check if outs_shape can be expressed without using
# variables in self.inputs. # variables in self.inputs.
# The shapes of node.inputs are valid. # The shapes of node.inputs are valid.
validator = scan_utils.Validator( validator = scan_utils.Validator(valid=input_shapes,
valid=input_shapes, invalid=self.inputs,
invalid=self.inputs, valid_equivalent=out_equivalent)
valid_equivalent=out_equivalent)
offset = 1 + self.n_seqs offset = 1 + self.n_seqs
scan_outs = [x for x in input_shapes[offset:offset + n_outs]] scan_outs = [x for x in input_shapes[offset:offset + n_outs]]
...@@ -1699,7 +1696,7 @@ class Scan(PureOp): ...@@ -1699,7 +1696,7 @@ class Scan(PureOp):
scan_outs.append(tuple(shp)) scan_outs.append(tuple(shp))
scan_outs += [x for x in scan_outs += [x for x in
input_shapes[offset:offset + self.n_shared_outs]] input_shapes[offset:offset + self.n_shared_outs]]
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if self.as_while: if self.as_while:
...@@ -1763,7 +1760,7 @@ class Scan(PureOp): ...@@ -1763,7 +1760,7 @@ class Scan(PureOp):
j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx] j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1: if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True: if connection_pattern[j_inp_idx][iidx] == True:
for k in xrange(len(connection_pattern)): for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]: if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True connection_pattern[k][iidx] = True
...@@ -1887,18 +1884,18 @@ class Scan(PureOp): ...@@ -1887,18 +1884,18 @@ class Scan(PureOp):
# With the global mapping inferred, the individual mappings # With the global mapping inferred, the individual mappings
# can be produced # can be produced
mappings = {"outer_inp_from_outer_out" : {}, mappings = {"outer_inp_from_outer_out": {},
"inner_inp_from_outer_out" : {}, "inner_inp_from_outer_out": {},
"inner_out_from_outer_out" : {}, "inner_out_from_outer_out": {},
"inner_inp_from_outer_inp" : {}, "inner_inp_from_outer_inp": {},
"inner_out_from_outer_inp" : {}, "inner_out_from_outer_inp": {},
"outer_out_from_outer_inp" : {}, "outer_out_from_outer_inp": {},
"outer_inp_from_inner_inp" : {}, "outer_inp_from_inner_inp": {},
"inner_out_from_inner_inp" : {}, "inner_out_from_inner_inp": {},
"outer_out_from_inner_inp" : {}, "outer_out_from_inner_inp": {},
"outer_inp_from_inner_out" : {}, "outer_inp_from_inner_out": {},
"inner_inp_from_inner_out" : {}, "inner_inp_from_inner_out": {},
"outer_out_from_inner_out" : {}} "outer_out_from_inner_out": {}}
for (oinp, iinp, iout, oout) in izip(outer_input_indices, for (oinp, iinp, iout, oout) in izip(outer_input_indices,
inner_input_indices, inner_input_indices,
...@@ -1944,7 +1941,7 @@ class Scan(PureOp): ...@@ -1944,7 +1941,7 @@ class Scan(PureOp):
grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1 grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif self.n_mit_sot > 0: elif self.n_mit_sot > 0:
grad_steps = self.outer_mitsot_outs(outs)[0].shape[0] +\ grad_steps = self.outer_mitsot_outs(outs)[0].shape[0] +\
self.mintaps[self.n_mit_mot] self.mintaps[self.n_mit_mot]
else: else:
grad_steps = inputs[0] grad_steps = inputs[0]
...@@ -2031,14 +2028,13 @@ class Scan(PureOp): ...@@ -2031,14 +2028,13 @@ class Scan(PureOp):
# to X. # to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()]) known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad( grads = gradient.grad(cost=None,
cost=None, known_grads=known_grads,
known_grads=known_grads, wrt=wrt,
wrt=wrt, consider_constant=wrt,
consider_constant=wrt, disconnected_inputs='ignore',
disconnected_inputs='ignore', return_disconnected='None',
return_disconnected='None', null_gradients='return')
null_gradients='return')
for i in range(len(wrt)): for i in range(len(wrt)):
gmp[wrt[i]] = grads[i] gmp[wrt[i]] = grads[i]
...@@ -2098,7 +2094,6 @@ class Scan(PureOp): ...@@ -2098,7 +2094,6 @@ class Scan(PureOp):
dC_dXt = safe_new(dC_douts[idx][0]) dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt) dC_dXts.append(dC_dXt)
known_grads = OrderedDict() known_grads = OrderedDict()
dc_dxts_idx = 0 dc_dxts_idx = 0
for i in range(len(diff_outputs)): for i in range(len(diff_outputs)):
...@@ -2153,7 +2148,7 @@ class Scan(PureOp): ...@@ -2153,7 +2148,7 @@ class Scan(PureOp):
dC_dXtm1s.append(safe_new(dC_dXts[opos])) dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype: if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \ dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype) x.astype(dC_dXts[opos].dtype)
else: else:
dC_dXtm1s.append(safe_new(x)) dC_dXtm1s.append(safe_new(x))
...@@ -2180,7 +2175,7 @@ class Scan(PureOp): ...@@ -2180,7 +2175,7 @@ class Scan(PureOp):
seq = outs[idx] seq = outs[idx]
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
if outmaxtap - k != 0: if outmaxtap - k != 0:
nw_seq = seq[k - mintap: -(outmaxtap-k)][::-1] nw_seq = seq[k - mintap: -(outmaxtap - k)][::-1]
else: else:
nw_seq = seq[k - mintap:][::-1] nw_seq = seq[k - mintap:][::-1]
outer_inp_seqs.append(nw_seq) outer_inp_seqs.append(nw_seq)
...@@ -2288,7 +2283,6 @@ class Scan(PureOp): ...@@ -2288,7 +2283,6 @@ class Scan(PureOp):
new_inner_out_mitmot = theano.clone(new_inner_out_mitmot, new_inner_out_mitmot = theano.clone(new_inner_out_mitmot,
replace=[(to_replace, replacement)]) replace=[(to_replace, replacement)])
inner_out_mitmot.append(new_inner_out_mitmot) inner_out_mitmot.append(new_inner_out_mitmot)
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
...@@ -2553,8 +2547,7 @@ class Scan(PureOp): ...@@ -2553,8 +2547,7 @@ class Scan(PureOp):
gradients.append(NullType(t)()) gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for p, (x, t) in enumerate( for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
zip(outputs[:end], type_outs[:end])):
if t == 'connected': if t == 'connected':
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == 'disconnected': elif t == 'disconnected':
...@@ -2587,12 +2580,11 @@ class Scan(PureOp): ...@@ -2587,12 +2580,11 @@ class Scan(PureOp):
start = len(gradients) start = len(gradients)
gradients += [DisconnectedType()() gradients += [DisconnectedType()()
for x in xrange(self.n_nit_sot)] for x in xrange(self.n_nit_sot)]
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
for p, (x, t) in enumerate( for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
zip(outputs[begin:end], type_outs[begin:end])):
if t == 'connected': if t == 'connected':
gradients.append(x[-1]) gradients.append(x[-1])
elif t == 'disconnected': elif t == 'disconnected':
...@@ -2629,7 +2621,7 @@ class Scan(PureOp): ...@@ -2629,7 +2621,7 @@ class Scan(PureOp):
self.outputs, '_rop') self.outputs, '_rop')
self_inputs = rval[0] self_inputs = rval[0]
rop_of_inputs = rval[0][:self.n_seqs + self.n_outs] + \ rop_of_inputs = rval[0][:self.n_seqs + self.n_outs] + \
rval[0][self.n_seqs + self.n_outs + self.n_shared_outs:] rval[0][self.n_seqs + self.n_outs + self.n_shared_outs:]
self_outputs = rval[1] self_outputs = rval[1]
# Step 1. Compute the R_op of the inner function # Step 1. Compute the R_op of the inner function
inner_eval_points = [scan_utils.safe_new(x, '_evalpoint') inner_eval_points = [scan_utils.safe_new(x, '_evalpoint')
...@@ -2640,8 +2632,7 @@ class Scan(PureOp): ...@@ -2640,8 +2632,7 @@ class Scan(PureOp):
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
if self.info['n_shared_outs'] > 0: if self.info['n_shared_outs'] > 0:
rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']] rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']]
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
inner_eval_points)
if type(rop_outs) not in (list, tuple): if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs] rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan # Step 2. Figure out what corresponds to what in the scan
...@@ -2721,8 +2712,8 @@ class Scan(PureOp): ...@@ -2721,8 +2712,8 @@ class Scan(PureOp):
e = e + self.n_mit_sot e = e + self.n_mit_sot
ib = ie ib = ie
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:\ self.tap_array[self.n_mit_mot: \
self.n_mit_mot + self.n_mit_sot]])) self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = [] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]): for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None: if evp is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论