提交 34d8a4de authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fix PEP8 violations in several places

Note that I only went through the files and fixed whatever the PEP8 script for vim reported as incorrect; I did not actually pay close attention to the text, so there might still be inconsistencies (for example, strange variable names). Also, only scan_opt.py is done completely; scan_op.py still has a number of PEP8 inconsistencies. After the fix I ran only the tests in scan_module/tests.py, and they passed, so I assume I have not introduced any bugs with my PEP8 fixes.
上级 1d7e3679
......@@ -5,10 +5,10 @@ See scan.py for details on scan
"""
__docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin " )
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
......@@ -39,15 +39,11 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
class Scan(PureOp):
#
# OLD DOCUMENTATION CAN BE FOUND NEAR REVISION 2581
#
def __init__( self
, inputs
, outputs
, info
, typeConstructor = None
def __init__(self,
inputs,
outputs,
info,
typeConstructor=None,
):
"""
:param inputs: inputs of the inner function of scan
......@@ -56,7 +52,7 @@ class Scan(PureOp):
the scan op.
"""
# adding properties into self
self.inputs = inputs
self.inputs = inputs
self.outputs = outputs
self.__dict__.update(info)
# I keep a version of info in self, to use in __eq__ and __hash__,
......@@ -70,15 +66,16 @@ class Scan(PureOp):
jdx = 0
if typeConstructor is None:
typeConstructor = lambda broadcastable, dtype: TensorType(
broadcastable = broadcastable, dtype = dtype)
broadcastable=broadcastable, dtype=dtype)
while idx < self.n_mit_mot_outs:
# Not that for mit_mot there are several output slices per
# output sequence
o = outputs[idx]
o = outputs[idx]
self.output_types.append(
typeConstructor( broadcastable = (False,) + o.type.broadcastable
, dtype = o.type.dtype)
typeConstructor(
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype)
)
idx += len(self.mit_mot_out_slices[jdx])
jdx += 1
......@@ -88,32 +85,32 @@ class Scan(PureOp):
for o in outputs[idx:end]:
self.output_types.append(
typeConstructor(
broadcastable = (False,) + o.type.broadcastable
, dtype = o.type.dtype ))
broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype))
# shared outputs + possibly the ending condition
for o in outputs[end:]:
self.output_types.append( o.type )
self.output_types.append(o.type)
if self.as_while:
self.output_types = self.output_types[:-1]
self.destroy_map = {}
if hasattr(self,'inplace') and self.inplace:
if hasattr(self, 'inplace') and self.inplace:
for idx in xrange(self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot ):
self.n_sit_sot):
self.destroy_map[idx] = [idx + 1 + self.n_seqs]
mode_instance = compile.mode.get_mode(self.mode)
# if the default mode is used, and that mode is ProfileMode
# then we need to copy the mode otherwise the time for a given
# op will be counted multiple times
if ( self.mode is None and
isinstance(mode_instance, compile.profilemode.ProfileMode) ):
if (self.mode is None and
isinstance(mode_instance, compile.profilemode.ProfileMode)):
mode_instance = compile.profilemode.ProfileMode(
optimizer = mode_instance.provided_optimizer
, linker = mode_instance.provided_linker )
compile.profilemode.prof_mode_instance_to_print.append(mode_instance)
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker)
compile.profilemode.prof_mode_instance_to_print.append(
mode_instance)
self.mode_instance = mode_instance
if self.name:
self.mode_instance.message = self.name + " sub profile"
......@@ -122,7 +119,7 @@ class Scan(PureOp):
else:
self.mode_instance = mode_instance
if not hasattr(self,'name') or self.name is None:
if not hasattr(self, 'name') or self.name is None:
self.name = 'scan_fn'
# to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the
......@@ -130,27 +127,26 @@ class Scan(PureOp):
self.info['name'] = self.name
# Pre-computing some values to speed up perform
self.mintaps = [ numpy.min(x) for x in self.tap_array]
self.mintaps += [ 0 for x in xrange(self.n_nit_sot) ]
self.seqs_arg_offset = 1+self.n_seqs
self.shared_arg_offset = ( self.seqs_arg_offset
+ self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot )
self.nit_sot_arg_offset = ( self.shared_arg_offset +
self.n_shared_outs )
self.mintaps = [numpy.min(x) for x in self.tap_array]
self.mintaps += [0 for x in xrange(self.n_nit_sot)]
self.seqs_arg_offset = 1 + self.n_seqs
self.shared_arg_offset = (self.seqs_arg_offset +
self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot)
self.nit_sot_arg_offset = (self.shared_arg_offset +
self.n_shared_outs)
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
if not self.info['gpu']:
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs)
local_env = gof.Env(tmp_in, tmp_out)
self._cmodule_key = gof.CLinker.cmodule_key_(local_env,[])
self._cmodule_key = gof.CLinker.cmodule_key_(local_env, [])
self._hash_inner_graph = hash(self._cmodule_key)
else:
self._hash_inner_graph = self.info['gpu_hash']
def make_node(self, *inputs):
assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
# assert dtype is consistent
......@@ -173,23 +169,23 @@ class Scan(PureOp):
# Flags that indicate which inputs are vectors
self.vector_seqs = [ seq.ndim == 1 for seq in
inputs[1:1+self.n_seqs ] ]
self.vector_outs = [ arg.ndim ==1 for arg in
inputs[1+self.n_seqs: (1+self.n_seqs +
self.n_outs)] ]
self.vector_outs += [ False]*self.n_nit_sot
self.vector_seqs = [seq.ndim == 1 for seq in
inputs[1:1 + self.n_seqs]]
self.vector_outs = [arg.ndim == 1 for arg in
inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)]]
self.vector_outs += [False] * self.n_nit_sot
# Check if input sequences and variables representing a slice of
# them have the same dtype
for idx in xrange(self.n_seqs):
if inputs[1+idx].dtype != self.inputs[idx].dtype:
raise ValueError(err_msg1%( 'sequence'
, str(inputs[1+idx])
, idx
, inputs[1+idx].dtype
, str(self.inputs[idx])
, self.inputs[idx].dtype) )
if inputs[1 + idx].dtype != self.inputs[idx].dtype:
raise ValueError(err_msg1 % ('sequence',
str(inputs[1 + idx]),
idx,
inputs[1 + idx].dtype,
str(self.inputs[idx]),
self.inputs[idx].dtype))
# Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output
......@@ -198,73 +194,73 @@ class Scan(PureOp):
# Maybe checking that ndim fits would be good as well !?
index_i = self.n_seqs
index_o = 0
index = 1+self.n_seqs
start = index
end = index + self.n_mit_mot
index = 1 + self.n_seqs
start = index
end = index + self.n_mit_mot
while index < end:
for k in self.tap_array[index-start]:
for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'initial state (outputs_info'
' in scan nomenclature) '
, str(inputs[index])
, index
, inputs[index].dtype
, str(self.inputs[index_i])
, self.inputs[index_i].dtype) )
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(inputs[index]),
index,
inputs[index].dtype,
str(self.inputs[index_i]),
self.inputs[index_i].dtype))
index_i += 1
for k in self.mit_mot_out_slices[index-start]:
for k in self.mit_mot_out_slices[index - start]:
if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2%( str(inputs[index])
, index
, inputs[index].dtype
, self.outputs[index_o].dtype) )
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index_o += 1
index += 1
# Same checks as above but for outputs of type mit_sot and sit_sot
end += self.n_mit_sot + self.n_sit_sot
while index < end:
for k in self.tap_array[index-start]:
for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'Initial state'
, str(inputs[index])
, index
, inputs[index].dtype
, str(self.inputs[index_i])
, self.inputs[index_i].dtype) )
raise ValueError(err_msg1 % ('Initial state',
str(inputs[index]),
index,
inputs[index].dtype,
str(self.inputs[index_i]),
self.inputs[index_i].dtype))
index_i += 1
if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2%( str(inputs[index])
, index
, inputs[index].dtype
, self.outputs[index_o].dtype) )
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index_o += 1
index += 1
index += 1
# Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?!
end += self.n_shared_outs
end += self.n_shared_outs
index_o += self.n_nit_sot
while index < end:
if (hasattr(inputs[index],'dtype') and
if (hasattr(inputs[index], 'dtype') and
inputs[index].dtype != self.outputs[index_o].dtype):
raise ValueError(err_msg2%( str(inputs[index])
, index
, inputs[index].dtype
, self.outputs[index_o].dtype) )
index += 1
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index += 1
index_o += 1
for x in inputs[index:index+ self.n_nit_sot]:
for x in inputs[index:index + self.n_nit_sot]:
# For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
if (str(x.dtype)[:3] not in ('uin','int') or
if (str(x.dtype)[:3] not in ('uin', 'int') or
x.ndim != 0):
raise ValueError('For output %d you need to provide a '
'scalar int !',x)
'scalar int !', x)
apply_node = Apply(self
, inputs
, [t() for t in self.output_types])
apply_node = Apply(self,
inputs,
[t() for t in self.output_types])
return apply_node
def __eq__(self, other):
......@@ -284,7 +280,7 @@ class Scan(PureOp):
# check. Namely, do the internal graph represent same
# computations
for self_in, other_in in zip(self.inputs, other.inputs):
if self_in.type != other_in.type :
if self_in.type != other_in.type:
return False
if not scan_utils.equal_computations(self.outputs,
......@@ -308,21 +304,19 @@ class Scan(PureOp):
else:
name = 'for'
if self.inplace :
aux_txt = '%s{inplace,%s,%s}'%(name, gpu_str, str(self.name))
if self.inplace:
aux_txt = '%s{inplace,%s,%s}' % (name, gpu_str, str(self.name))
else:
aux_txt = '%s{%s,%s}'%(name,gpu_str, str(self.name))
aux_txt = '%s{%s,%s}' % (name, gpu_str, str(self.name))
return aux_txt
def __hash__(self):
return ( hash(type(self)) ^
return (hash(type(self)) ^
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self._hash_inner_graph ^
scan_utils.hash_listsDictsTuples(self.info) )
scan_utils.hash_listsDictsTuples(self.info))
def make_thunk(self, node, storage_map, compute_map, no_recycling):
"""
......@@ -348,7 +342,6 @@ class Scan(PureOp):
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
......@@ -357,64 +350,65 @@ class Scan(PureOp):
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
slices = ( self.n_mit_mot_outs +
slices = (self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot )
wrapped_inputs = [Param(x, borrow=True) for x in self.inputs ]
self.n_nit_sot)
wrapped_inputs = [Param(x, borrow=True) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=True) for x in
self.outputs[:slices] ]
self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
profile = None
if (theano.config.profile or (isinstance(self.profile, (basestring, bool, int))
if (theano.config.profile or
(isinstance(self.profile, (basestring, bool, int))
and self.profile)):
if isinstance(self.profile, basestring):
profile = ScanProfileStats(name = self.profile)
profile = ScanProfileStats(name=self.profile)
else:
profile = ScanProfileStats(name = self.name)
profile = ScanProfileStats(name=self.name)
elif self.profile:
profile = self.profile
self.fn = function(wrapped_inputs,
wrapped_outputs,
mode = self.mode_instance,
name = self.name,
profile = profile)
mode=self.mode_instance,
name=self.name,
profile=profile)
try:
cython_mintaps = numpy.asarray(self.mintaps, dtype = 'int32')
raise ImportError
cython_mintaps = numpy.asarray(self.mintaps, dtype='int32')
cython_tap_array_len = \
numpy.asarray([ len(x) for x in self.tap_array],
numpy.asarray([len(x) for x in self.tap_array],
dtype='int32')
if len(self.tap_array) == 0:
d1 = 0
else:
d1 = numpy.max(cython_tap_array_len)
d0 = len(self.tap_array)
cython_tap_array = numpy.zeros((d0,d1), dtype='int32')
cython_tap_array = numpy.zeros((d0, d1), dtype='int32')
for _d0 in range(d0):
for _d1 in range(cython_tap_array_len[_d0]):
cython_tap_array[_d0,_d1] = self.tap_array[_d0][_d1]
cython_tap_array[_d0, _d1] = self.tap_array[_d0][_d1]
cython_mit_mot_out_nslices = \
numpy.asarray([ len(x) for x in self.mit_mot_out_slices],
numpy.asarray([len(x) for x in self.mit_mot_out_slices],
dtype='int32')
if len(self.mit_mot_out_slices) == 0:
d1 = 0
else:
d1 = numpy.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = numpy.zeros((d0,d1),
cython_mit_mot_out_slices = numpy.zeros((d0, d1),
dtype='int32')
for _d0 in range(d0):
for _d1 in range(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0,_d1] = \
cython_mit_mot_out_slices[_d0, _d1] = \
self.mit_mot_out_slices[_d0][_d1]
vector_seqs = [ seq.ndim == 1 for seq in
node.inputs[1:1+self.n_seqs ] ]
vector_outs = [ arg.ndim ==1 for arg in
node.inputs[1+self.n_seqs: (1+self.n_seqs +
self.n_outs)] ]
vector_outs += [ False]*self.n_nit_sot
vector_seqs = [seq.ndim == 1 for seq in
node.inputs[1:1 + self.n_seqs]]
vector_outs = [arg.ndim == 1 for arg in
node.inputs[1 + self.n_seqs:
(1 + self.n_seqs + self.n_outs)]]
vector_outs += [False] * self.n_nit_sot
cython_vector_seqs = numpy.asarray(self.vector_seqs,
dtype='int32')
......@@ -448,6 +442,7 @@ class Scan(PureOp):
except ImportError:
p = self.execute
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
......@@ -463,14 +458,14 @@ class Scan(PureOp):
return self.inputs[:self.n_seqs]
def outer_seqs(self, node):
return node.inputs[1:1+self.n_seqs]
return node.inputs[1:1 + self.n_seqs]
def inner_mitmot(self):
n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
return self.inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, node):
return node.inputs[1+self.n_seqs:1+self.n_seqs+self.n_mit_mot]
return node.inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
......@@ -490,80 +485,80 @@ class Scan(PureOp):
ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
return self.inputs[self.n_seqs+n_mitmot_taps:
self.n_seqs+ntaps_upto_sit_sot]
return self.inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, node):
offset = 1+self.n_seqs+self.n_mit_mot
return node.inputs[offset:offset+self.n_mit_sot]
offset = 1 + self.n_seqs + self.n_mit_mot
return node.inputs[offset:offset + self.n_mit_sot]
def inner_mitsot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[n_taps:n_taps+self.n_mit_sot]
return self.outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, node):
return node.outputs[self.n_mit_mot:self.n_mit_mot+self.n_mit_sot]
return node.outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self):
return self.tap_array[self.n_mit_mot:self.n_mit_mot+self.n_mit_sot]
return self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
def inner_sitsot(self):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot
return self.inputs[offset:offset+self.n_sit_sot]
return self.inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self,node):
offset = 1+self.n_seqs+self.n_mit_mot + self.n_mit_sot
return node.inputs[offset:offset+self.n_sit_sot]
def outer_sitsot(self, node):
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return node.inputs[offset:offset + self.n_sit_sot]
def inner_sitsot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps
return self.outputs[offset:offset+self.n_sit_sot]
return self.outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, node):
offset = self.n_mit_mot + self.n_mit_sot
return node.outputs[offset:offset+self.n_sit_sot]
return node.outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, node):
offset = (1 + self.n_seqs+self.n_mit_mot + self.n_mit_sot +
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs)
return node.inputs[offset:offset+self.n_nit_sot]
return node.inputs[offset:offset + self.n_nit_sot]
def inner_nitsot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot
return self.outputs[offset:offset+self.n_nit_sot]
return self.outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, node):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return node.outputs[offset:offset+self.n_nit_sot]
return node.outputs[offset:offset + self.n_nit_sot]
def inner_shared(self):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot
return self.inputs[offset:offset+self.n_shared_outs]
return self.inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, node):
offset = (1 + self.n_seqs+self.n_mit_mot + self.n_mit_sot +
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot)
return node.inputs[offset:offset+self.n_shared_outs]
return node.inputs[offset:offset + self.n_shared_outs]
def inner_shared_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot
return self.outputs[offset:offset+self.n_shared_outs]
return self.outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, node):
offset = ( self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot)
return node.outputs[offset:offset+self.n_shared_outs]
return node.outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self):
n_taps_upto_sit_sot = sum(len(x) for x in
......@@ -574,12 +569,11 @@ class Scan(PureOp):
return self.inputs[offset:]
def outer_non_seqs(self, node):
offset = ( 1+ self.n_seqs + self.n_mit_mot + self.n_mit_sot +
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return node.inputs[offset:]
def execute( self, node, args, outs):
def execute(self, node, args, outs):
"""
The args are packed like this:
......@@ -607,7 +601,7 @@ class Scan(PureOp):
# negative flip sequences around, and make n_steps positive
t0_call = time.time()
t_fn = 0
n_steps = args[0]
n_steps = args[0]
seqs = []
if n_steps < 0:
n_steps = abs(n_steps)
......@@ -616,7 +610,7 @@ class Scan(PureOp):
raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps,
node.inputs[1+idx],
node.inputs[1 + idx],
seq.shape)
seqs.append(seq[::-1])
else:
......@@ -625,35 +619,37 @@ class Scan(PureOp):
raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps,
node.inputs[1+idx],
node.inputs[1 + idx],
seq.shape)
seqs.append(seq)
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containting the length of each output
# pos -- map containing the current position of each output
# pos -- map containing the current position of each
# output
store_steps = [ arg.shape[0] for arg
store_steps = [arg.shape[0] for arg
in args[self.seqs_arg_offset:
self.shared_arg_offset] ]
store_steps += [ arg for arg in
self.shared_arg_offset]]
store_steps += [arg for arg in
args[self.nit_sot_arg_offset:
self.nit_sot_arg_offset+self.n_nit_sot]
self.nit_sot_arg_offset + self.n_nit_sot]
]
pos = [ (-self.mintaps[idx])%store_steps[idx] for idx
in xrange(self.n_outs+self.n_nit_sot)]
pos = [(-self.mintaps[idx]) % store_steps[idx] for idx
in xrange(self.n_outs + self.n_nit_sot)]
# 2.1 Create storage space for outputs
for idx in xrange(self.n_outs):
if self.inplace:
# ^ Case 1. Outputs should be computed inplace of their
# initial state
outs[idx][0] = args[self.seqs_arg_offset + idx ]
elif ( outs[idx][0] is not None and
outs[idx][0].shape[1:] == args[self.seqs_arg_offset + idx].shape[1:]
and outs[idx][0].shape[0] >= store_steps[idx] ):
outs[idx][0] = args[self.seqs_arg_offset + idx]
elif (outs[idx][0] is not None and
outs[idx][0].shape[1:] == args[self.seqs_arg_offset +
idx].shape[1:]
and outs[idx][0].shape[0] >= store_steps[idx]):
# Put in the values of the initial state
outs[idx][0] = outs[idx][0][:store_steps[idx]]
outs[idx][0] = outs[idx][0][:store_steps[idx]]
if idx > self.n_mit_mot:
l = - self.mintaps[idx]
outs[idx][0][:l] = args[self.seqs_arg_offset + idx][:l]
......@@ -662,28 +658,28 @@ class Scan(PureOp):
else:
outs[idx][0] = args[self.seqs_arg_offset + idx].copy()
offset = self.nit_sot_arg_offset + self.n_nit_sot
other_args = args[offset:]
input_storage = self.fn.input_storage
output_storage = self.fn.output_storage
fn = self.fn.fn
offset = ( self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
self.n_shared_outs)
for idx in xrange(len(other_args)):
input_storage[idx+offset].storage[0] = other_args[idx]
input_storage[idx + offset].storage[0] = other_args[idx]
i = 0
cond = True
############## THE MAIN LOOP #########################
#for i in xrange(n_steps):
while (i< n_steps) and cond:
while (i < n_steps) and cond:
# sequences over which scan iterates
# 3. collect input slices
for idx in xrange(self.n_seqs):
if self.vector_seqs[idx]:
input_storage[idx].storage[0] = seqs[idx][i:i+1].reshape(())
input_storage[idx].storage[0] = \
seqs[idx][i:i + 1].reshape(())
else:
input_storage[idx].storage[0] = seqs[idx][i]
......@@ -691,26 +687,25 @@ class Scan(PureOp):
for idx in xrange(self.n_outs):
if self.vector_outs[idx]:
for tap in self.tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx]
_idx = (pos[idx] + tap) % store_steps[idx]
input_storage[offset].storage[0] =\
outs[idx][0][_idx:_idx+1].reshape(())
outs[idx][0][_idx:_idx + 1].reshape(())
offset += 1
else:
for tap in self.tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx]
_idx = (pos[idx] + tap) % store_steps[idx]
input_storage[offset].storage[0] = outs[idx][0][_idx]
offset += 1
a_offset = self.shared_arg_offset
o_offset = self.n_outs + self.n_nit_sot
if i == 0:
for j in xrange(self.n_shared_outs):
input_storage[offset].storage[0] = args[a_offset+j]
input_storage[offset].storage[0] = args[a_offset + j]
offset += 1
else:
for j in xrange(self.n_shared_outs):
input_storage[offset].storage[0] = outs[o_offset+j][0]
input_storage[offset].storage[0] = outs[o_offset + j][0]
offset += 1
# 4. collecting slices where the output should be stored
......@@ -718,23 +713,24 @@ class Scan(PureOp):
output_storage[idx].storage[0] = None
offset = self.n_mit_mot_outs
if i !=0 and self.n_nit_sot >0:
if i != 0 and self.n_nit_sot > 0:
for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot):
if ( store_steps[idx+self.n_mit_mot] == 1 or
self.vector_outs[idx+self.n_mit_mot]):
output_storage[idx+offset].storage[0] = None
if (store_steps[idx + self.n_mit_mot] == 1 or
self.vector_outs[idx + self.n_mit_mot]):
output_storage[idx + offset].storage[0] = None
else:
output_storage[idx+offset].storage[0] =\
outs[idx+self.n_mit_mot][0][pos[idx+self.n_mit_mot]]
_pos0 = idx + self.n_mit_mot
output_storage[idx + offset].storage[0] =\
outs[_pos0][0][pos[_pos0]]
else:
for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot):
output_storage[idx+offset].storage[0] = None
output_storage[idx + offset].storage[0] = None
offset += self.n_outs+self.n_nit_sot - self.n_mit_mot
offset += self.n_outs + self.n_nit_sot - self.n_mit_mot
for idx in xrange(self.n_shared_outs):
output_storage[idx+offset].storage[0] = None
output_storage[idx + offset].storage[0] = None
# If condition add it to the mix
if self.as_while:
pdx = offset + self.n_shared_outs
......@@ -762,97 +758,102 @@ class Scan(PureOp):
# 5.1 Copy over the values for mit_mot outputs
for j in xrange(self.n_mit_mot):
for k in self.mit_mot_out_slices[j]:
outs[j][0][k+pos[j]] = output_storage[offset_out].storage[0]
outs[j][0][k + pos[j]] = \
output_storage[offset_out].storage[0]
offset_out += 1
# 5.2 Copy over the values for mit_sot/sit_sot outputs
begin = self.n_mit_mot
end = self.n_outs
end = self.n_outs
offset_out -= self.n_mit_mot
for j in xrange(begin, end):
if ( store_steps[j] == 1 or self.vector_outs[j] or
outs[j][0][pos[j]] is not output_storage[offset_out+j].storage[0]):
outs[j][0][pos[j]] = output_storage[offset_out+j].storage[0]
if (store_steps[j] == 1 or self.vector_outs[j] or
outs[j][0][pos[j]] is not
output_storage[offset_out + j].storage[0]):
outs[j][0][pos[j]] = \
output_storage[offset_out + j].storage[0]
# 5.3 Copy over the values for nit_sot outputs
begin = end
end += self.n_nit_sot
for j in xrange(begin,end):
begin = end
end += self.n_nit_sot
for j in xrange(begin, end):
if i == 0:
jout = j+offset_out
shape = (store_steps[j],) + output_storage[jout].storage[0].shape
jout = j + offset_out
shape = (store_steps[j],) + \
output_storage[jout].storage[0].shape
if len(output_storage[jout].storage[0].shape) == 0:
self.vector_outs[j] = True
dtype = output_storage[jout].storage[0].dtype
if (outs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype ):
outs[j][0].dtype != dtype):
if self.gpu:
outs[j][0] = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape)
_cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
outs[j][0] = _cuda.zeros(shape)
else:
outs[j][0] = numpy.zeros(shape, dtype)
elif outs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = output_storage[jout].storage[0]
elif (store_steps[j] == 1 or self.vector_outs[j] or
outs[j][0][pos[j]] is not output_storage[j+offset_out].storage[0]):
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
outs[j][0][pos[j]] is not
output_storage[j + offset_out].storage[0]):
outs[j][0][pos[j]] = \
output_storage[j + offset_out].storage[0]
# 5.4 Copy over the values for outputs corresponding to shared
# variables
begin = end
end += self.n_shared_outs
for j in xrange(begin,end):
jout = j +offset_out
begin = end
end += self.n_shared_outs
for j in xrange(begin, end):
jout = j + offset_out
outs[j][0] = output_storage[jout].storage[0]
pos = [ (idx+1)%store for idx,store in
itertools.izip(pos, store_steps)
]
i = i+1
pos = [(idx + 1) % store for idx, store in
itertools.izip(pos, store_steps)]
i = i + 1
# 6. Check if you need to re-order output buffers
begin = self.n_mit_mot
end = self.n_outs + self.n_nit_sot
end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end):
min_tap = self.mintaps[idx]
if ( store_steps[idx] < i-self.mintaps[idx] and
pos[idx] < store_steps[idx] ):
if (store_steps[idx] < i - self.mintaps[idx] and
pos[idx] < store_steps[idx]):
pdx = pos[idx]
if pdx < store_steps[idx]//2 :
shape = (pdx,)+ outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance( outs[idx][0],
if pdx < store_steps[idx] // 2:
shape = (pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance(outs[idx][0],
cuda.CudaNdarray):
tmp = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape)
_cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
tmp = _cuda.zeros(shape)
else:
tmp = numpy.empty(shape)
tmp[:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = tmp
outs[idx][0][:store_steps[idx] - pdx] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx] - pdx:] = tmp
del tmp
else:
shape = (store_steps[idx]-pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance( outs[idx][0],
shape = (store_steps[idx] - pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance(outs[idx][0],
cuda.CudaNdarray):
tmp = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape)
_cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
tmp = _cuda.zeros(shape)
else:
tmp = numpy.empty(shape)
tmp[:] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = tmp
outs[idx][0][store_steps[idx] - pdx:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx] - pdx] = tmp
del tmp
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenarion Scan is
# expected to return 0 for all entries for which the gradient is
# not actually computed
elif store_steps[idx] > i - self.mintaps[idx]:
outs[idx][0][i-self.mintaps[idx]:] = 0
outs[idx][0][i - self.mintaps[idx]:] = 0
# This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output
# to have that length ( and not the maximal length possible)
......@@ -883,7 +884,7 @@ class Scan(PureOp):
profile.callcount += 1
profile.nbsteps += n_steps
profile.call_time += t_call
profile.vm_call_time += t_fn
profile.vm_call_time += t_fn
if hasattr(self.fn.fn, 'update_profile'):
self.fn.fn.update_profile(profile)
......@@ -896,7 +897,7 @@ class Scan(PureOp):
#self.fn.maker.mode.fn_time += t_fn
# Old Profile Mode */
self.t_call = t_call
self.t_fn = t_fn
self.t_fn = t_fn
### Infer Shape
def infer_shape(self, node, input_shapes):
......@@ -905,26 +906,27 @@ class Scan(PureOp):
# is the shape of self.inputs[i]
# sequences
seqs_shape = [ x[1:] for x in input_shapes[1:1+self.n_seqs] ]
seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]
# mit_mot, mit_sot, sit_sot
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outs_shape = []
for idx in xrange(n_outs):
for k in self.tap_array[idx]:
outs_shape += [ input_shapes[idx+self.n_seqs+1][1:] ]
outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]
# shared_outs
offset = 1 + self.n_seqs + n_outs
for idx in xrange(self.n_shared_outs):
outs_shape += [ input_shapes[idx+offset] ]
outs_shape += [input_shapes[idx + offset]]
# non_sequences
offset += self.n_nit_sot + self.n_shared_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inputs)
# Non-sequences have a direct equivalent from self.inputs in node.inputs
# Non-sequences have a direct equivalent from self.inputs in
# node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
out_equivalent = {}
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
......@@ -934,22 +936,22 @@ class Scan(PureOp):
else:
self_outs = self.outputs
outs_shape = scan_utils.infer_shape(
outs = self_outs,
inputs = self.inputs,
input_shapes = inner_ins_shapes)
outs=self_outs,
inputs=self.inputs,
input_shapes=inner_ins_shapes)
# Will be used to check if outs_shape can be expressed without using
# variables in self.inputs.
# The shapes of node.inputs are valid.
validator = scan_utils.Validator(
valid = input_shapes,
invalid = self.inputs,
valid_equivalent = out_equivalent)
valid=input_shapes,
invalid=self.inputs,
valid_equivalent=out_equivalent)
offset = 1 + self.n_seqs
scan_outs = [x for x in input_shapes[offset:offset+n_outs]]
scan_outs = [x for x in input_shapes[offset:offset + n_outs]]
offset += n_outs
for x in xrange(self.n_nit_sot):
out_shape_x = outs_shape[n_outs+x]
out_shape_x = outs_shape[n_outs + x]
if out_shape_x is None:
# This output is not a tensor, and has no shape
scan_outs.append(None)
......@@ -957,10 +959,10 @@ class Scan(PureOp):
# We need to make sure that we can compute the shapes from
# node.inputs, and constants, without using the variables
# in the inner function.
r = node.outputs[n_outs+x]
r = node.outputs[n_outs + x]
assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset+self.n_shared_outs+x]]
for i, shp_i in zip(xrange(1,r.ndim), out_shape_x):
shp = [node.inputs[offset + self.n_shared_outs + x]]
for i, shp_i in zip(xrange(1, r.ndim), out_shape_x):
# Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates
# whether variable is shp_i (if True), or an valid
......@@ -976,34 +978,32 @@ class Scan(PureOp):
shp.append(v_shp_i[0])
scan_outs.append(tuple(shp))
scan_outs += [ x for x in
input_shapes[offset:offset+self.n_shared_outs] ]
scan_outs += [x for x in
input_shapes[offset:offset + self.n_shared_outs]]
return scan_outs
### GRAD FUNCTION
def grad(self, args, g_outs):
# 1. forward pass - get the outputs after applying scan
scan_outputs = self(*args)
# 2. make sure they are given as a list
if not( type(scan_outputs) in (list,tuple)):
if not(type(scan_outputs) in (list, tuple)):
scan_outputs = [scan_outputs]
# 3. un-group / unzip the inputs
# Note ! We don't want to use the actual same variable as the ones
# used by the original scan, rather create clones of them
rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs,'_grad')
self_inputs = rval[0]
self.outputs, '_grad')
self_inputs = rval[0]
self_outputs = rval[1]
seqs = self_inputs[:self.n_seqs]
seqs = self_inputs[:self.n_seqs]
offset = self.n_seqs
n_ins_mit_mot = numpy.sum([0] + [ len(self.tap_array[x]) for x
in xrange(self.n_mit_mot) ])
outs_mit_mot = self_inputs[offset:offset+n_ins_mit_mot]
offset = self.n_seqs
n_ins_mit_mot = numpy.sum([0] + [len(self.tap_array[x]) for x
in xrange(self.n_mit_mot)])
outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot]
offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [ len(self.tap_array[x]) for x
......@@ -1082,6 +1082,11 @@ class Scan(PureOp):
# 7.3. compute gradients of the inputs given one output
for dx, out in enumerate(clean_outputs):
inner_g_out = safe_new(out)
###
#### I need to clip the gradient HERE !!
if g_outs_no_shared[dx]:
g_out_slices.append(g_outs_no_shared[dx][0])
else:
......
......@@ -4,11 +4,11 @@ This module provides optimizations for scan
__docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin "
"Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
......@@ -32,16 +32,20 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_opt')
list_opt_slice = [ tensor.opt.local_abs_merge,
tensor.opt.local_mul_switch_sink,
tensor.opt.local_upcast_elemwise_constant_inputs,
tensor.opt.local_remove_switch_const_cond,
tensor.opt.constant_folding ]
list_opt_slice = [tensor.opt.local_abs_merge,
tensor.opt.local_mul_switch_sink,
tensor.opt.local_upcast_elemwise_constant_inputs,
tensor.opt.local_remove_switch_const_cond,
tensor.opt.constant_folding]
def warning(*msg):
_logger.warning('WARNING theano.scan: '+' '.join(msg))
_logger.warning('WARNING theano.scan: ' + ' '.join(msg))
def info(*msg):
_logger.info('INFO theano.scan: '+' '.join(msg))
_logger.info('INFO theano.scan: ' + ' '.join(msg))
@gof.local_optimizer([None])
def remove_constants_and_unused_inputs_scan(node):
......@@ -58,9 +62,9 @@ def remove_constants_and_unused_inputs_scan(node):
return False
op = node.op
# We only need to take care of sequences and other arguments
st = op.n_seqs
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ]))
op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
st += op.n_sit_sot
st += op.n_shared_outs
op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)
......@@ -70,17 +74,17 @@ def remove_constants_and_unused_inputs_scan(node):
out_stuff_inner = op_ins[op.n_seqs:st]
non_seqs = op_ins[st:]
st = ( op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs +1 )
st = (op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs + 1)
outer_non_seqs = node.inputs[st:]
out_stuff_outer = node.inputs[1+op.n_seqs:st]
out_stuff_outer = node.inputs[1 + op.n_seqs:st]
# To replace constants in the outer graph by clones in the inner graph
givens = {}
givens = {}
# All the inputs of the inner graph of the new scan
nw_inner = []
# Same for the outer graph, initialized w/ number of steps
......@@ -88,18 +92,18 @@ def remove_constants_and_unused_inputs_scan(node):
all_ins = gof.graph.inputs(op_outs)
for idx in xrange(op.n_seqs):
if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and
node.inputs[idx+1].tag.unique_value is not None):
if (isinstance(node.inputs[idx + 1], tensor.TensorConstant) and
node.inputs[idx + 1].tag.unique_value is not None):
try:
# This works if input is a constant that has all entries
# equal
val = tensor.get_constant_value(node.inputs[idx+1])
givens[op_ins[idx]] = node.inputs[idx+1].clone()[0]
val = tensor.get_constant_value(node.inputs[idx + 1])
givens[op_ins[idx]] = node.inputs[idx + 1].clone()[0]
except TypeError:
pass
elif op_ins[idx] in all_ins:
nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx+1]]
nw_outer += [node.inputs[idx + 1]]
nw_n_seqs = len(nw_inner)
# Add outputs stuff
......@@ -114,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_outer += [nw_out]
if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace = givens)
op_outs = scan_utils.clone(op_outs, replace=givens)
nw_info = op.info.copy()
nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK
......@@ -128,11 +132,12 @@ scan_seqopt = theano.gof.SequenceDB()
optdb.register('scan_seqopt', scan_seqopt, 1.9, 'fast_run', 'scan')
scan_seqopt.register('scanOp_remove_constants_and_unused_inputs',
opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees = True),
ignore_newtrees=True),
5,
'fast_run',
'scan')
# This is a global opt for historical reason
# It should be possible to change it to a local opt.
class PushOutNonSeqScan(gof.Optimizer):
......@@ -140,10 +145,9 @@ class PushOutNonSeqScan(gof.Optimizer):
def __init__(self):
gof.Optimizer.__init__(self)
def add_requirements(self,env):
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate())
def apply(self, env):
nodelist = [x for x in env.toposort() if isinstance(x.op,
scan_op.Scan)]
......@@ -152,34 +156,31 @@ class PushOutNonSeqScan(gof.Optimizer):
def process_node(self, env, node):
# this flag tells if there was any change during the last iterations
changed = True
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_env = gof.Env(clean_inputs, clean_outputs)
max_iterations = 2*len(local_env.toposort()) + 3
max_iterations = 2 * len(local_env.toposort()) + 3
counts = 0
to_remove = []
to_replace = []
replace_with_in = []
to_remove = []
to_replace = []
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
st = op.n_seqs
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ]))
op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
st += op.n_sit_sot
st += op.n_shared_outs
non_seqs = clean_inputs[st:]
st = ( op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs +1 )
st = (op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs + 1)
outer_non_seqs = node.inputs[st:]
assert len(non_seqs) == len(outer_non_seqs)
while changed and counts < max_iterations:
......@@ -187,15 +188,15 @@ class PushOutNonSeqScan(gof.Optimizer):
changed = False
for nd in local_env.toposort():
if ( numpy.all([ (x in non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
if (numpy.all([(x in non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not isinstance(nd.op,theano.compile.ViewOp) and
not isinstance(nd.op,theano.compile.DeepCopyOp) and
not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node
not nd in to_remove
):
......@@ -206,49 +207,50 @@ class PushOutNonSeqScan(gof.Optimizer):
outside_ins = []
for x in nd.inputs:
if x in non_seqs:
outside_ins +=[ outer_non_seqs[non_seqs.index(x)]]
outside_ins += [outer_non_seqs[non_seqs.index(x)]]
elif x in to_replace:
outside_ins +=[replace_with_out[to_replace.index(x)]]
outside_ins += [
replace_with_out[to_replace.index(x)]]
elif isinstance(x, theano.Constant):
outside_ins +=[x.clone()]
outside_ins += [x.clone()]
else:
raise Exception(
('Error in the `scan_pushout_non_seq_operations`'
'. The optimization tries to move some '
'computation fron scan which is not allowed '
'to move. Report this on theano-users list'),x )
('Error in the `scan_pushout_non_seq_'
'operations`. The optimization tries '
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'), x)
nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements
for idx,y in enumerate(nd.outputs):
for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y,'_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
assert type(y) == type(nw_outer_node.outputs[idx])
replace_with_out += [nw_outer_node.outputs[idx]]
changed = True
if counts >= max_iterations:
raise Exception( ('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!'))
raise Exception('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!')
# We need to check all candidate replacements and choose those that
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
clean_to_replace = []
clean_replace_with_in = []
clean_to_replace = []
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [ nd for nd in local_env.toposort()
existent_nodes = [nd for nd in local_env.toposort()
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx,out in enumerate(to_replace):
for idx, out in enumerate(to_replace):
if out in to_keep and out.owner not in existent_nodes:
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
if len(clean_to_replace) > 0:
......@@ -256,7 +258,7 @@ class PushOutNonSeqScan(gof.Optimizer):
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip( clean_to_replace,
for to_repl, repl_in, repl_out in zip(clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
......@@ -274,7 +276,7 @@ class PushOutNonSeqScan(gof.Optimizer):
nwScan = scan_op.Scan(op_ins, op_outs, op.info)
nw_node = nwScan.make_node(* (node.inputs + nw_outer))
env.replace_all_validate(zip(node.outputs, nw_node.outputs),
reason = 'scan_push_computation_out')
reason='scan_push_computation_out')
return True
elif to_keep == []:
# Nothing in the inner graph should be kept
......@@ -290,7 +292,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# We need to add one extra dimension to the outputs
env.replace_all_validate(replace_with.items(),
reason = 'scan_push_computation_out')
reason='scan_push_computation_out')
else:
return False
......@@ -306,17 +308,17 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops',
@gof.local_optimizer([None])
def scan_make_inplace(node):
op = node.op
if ( isinstance(op, scan_op.Scan) and
if (isinstance(op, scan_op.Scan) and
(not op.info['inplace']) and
(not op.info['gpu'])):
info = op.info.copy()
info['inplace'] = True
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1+op.n_seqs]
ls = op.outer_mitmot(node)
ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node)
ls += op.outer_mitsot(node)
ls += op.outer_sitsot(node)
ls_end = op.outer_shared(node)
ls_end = op.outer_shared(node)
ls_end += op.outer_nitsot(node)
ls_end += op.outer_non_seqs(node)
n_outs = len(ls)
......@@ -325,19 +327,18 @@ def scan_make_inplace(node):
ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end
new_op = scan_op.Scan( op.inputs
, op.outputs
, info)
new_op = scan_op.Scan(op.inputs,
op.outputs,
info)
return new_op.make_node(*inputs).outputs
return False
optdb.register( 'scanOp_make_inplace'
, opt.in2out(scan_make_inplace,ignore_newtrees=True)
, 75
, 'fast_run'
, 'inplace'
, 'scan')
optdb.register('scanOp_make_inplace',
opt.in2out(scan_make_inplace, ignore_newtrees=True),
75,
'fast_run',
'inplace',
'scan')
class ScanSaveMem(gof.Optimizer):
......@@ -345,24 +346,25 @@ class ScanSaveMem(gof.Optimizer):
def __init__(self):
gof.Optimizer.__init__(self)
def add_requirements(self,env):
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate())
def process_node(self, env, node):
# helpful functions
def select_min(x,y):
def select_min(x, y):
if x is None:
return y
if y is None:
return x
return tensor.minimum(x,y)
def select_max(x,y):
return tensor.minimum(x, y)
def select_max(x, y):
if x is None:
return y
if y is None:
return x
return tensor.maximum(x,y)
return tensor.maximum(x, y)
def sanitize(x):
if x is None:
......@@ -383,9 +385,9 @@ class ScanSaveMem(gof.Optimizer):
op = node.op
c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot
init_l = [ 0 for x in xrange(op.n_mit_mot)]
init_l += [ abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:] ]
init_l += [ 0 for x in xrange(op.n_nit_sot)]
init_l = [0 for x in xrange(op.n_mit_mot)]
init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
init_l += [0 for x in xrange(op.n_nit_sot)]
# 2. Check the clients of each output and see for how many steps
# does scan need to run
......@@ -408,13 +410,13 @@ class ScanSaveMem(gof.Optimizer):
# change the number of steps in that case. To do this we set
# global_nsteps to None which is seen as a flag that nothing needs
# to be done
if len(node.outputs) <= c_outs :
global_nsteps = {'real' :-1, 'sym': []}
if len(node.outputs) <= c_outs:
global_nsteps = {'real': -1, 'sym': []}
else:
global_nsteps = None
# Keeps track of the original slices that each client represent
slices = [ None for o in node.outputs]
slices = [None for o in node.outputs]
# A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate
......@@ -425,31 +427,31 @@ class ScanSaveMem(gof.Optimizer):
# Note that for mit_mot outputs and shared outputs we can not change
# the number of intermediate steps stored without affecting the
# result of the op
store_steps = [ 0 for o in xrange(op.n_mit_mot)]
store_steps = [0 for o in xrange(op.n_mit_mot)]
store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
# Flag that says if an input has changed and we need to do something
# or not
flag_store = False
# 2.2 Loop over the clients
for i,out in enumerate(node.outputs[:c_outs]):
for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients
slices[i] = []
for cl,_ in out.clients:
for cl, _ in out.clients:
# 2.1 outputs of the function
#=> output needs all its intermediate values
if type(cl) == str:
# if the node is actually an output, then
# we need to store the entire thing
global_nsteps = None
slices[i] = None
global_nsteps = None
slices[i] = None
break
# 2.2 non-subtensor nodes
#=> output needs all its intermediate values
elif not isinstance(cl.op, tensor.basic.Subtensor):
global_nsteps = None
slices[i] = None
global_nsteps = None
slices[i] = None
break
# 2.3 subtensor nodes
#=> output might need to store just a subset of its values
......@@ -460,13 +462,11 @@ class ScanSaveMem(gof.Optimizer):
if this_slice == None:
# if unable to extract idx_list
#=> outputs needs all its intermediate values
global_nsteps = None
slices[i] = None
global_nsteps = None
slices[i] = None
break
# 2.3.2 extract the begin/end of the first dimension
if i > op.n_mit_mot:
try:
length = shape_of[out][0]
......@@ -479,26 +479,27 @@ class ScanSaveMem(gof.Optimizer):
length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice(
this_slice[0], length)
slices[i] += [(cf_slice,this_slice)]
slices[i] += [(cf_slice, this_slice)]
if ( isinstance(this_slice[0],slice) and
this_slice[0].stop is None ):
if (isinstance(this_slice[0], slice) and
this_slice[0].stop is None):
global_nsteps = None
break
if isinstance(cf_slice[0], slice):
stop = tensor.basic.extract_constant(cf_slice[0].stop)
stop = tensor.basic.extract_constant(cf_slice[0].stop)
else:
stop = tensor.basic.extract_constant(cf_slice[0]) + 1
stop = tensor.basic.extract_constant(cf_slice[0]) + 1
if stop == sys.maxint or stop == length:
stop = None
else:
# there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output as
# well. Which means that if have a initial state of
# length 3, and you look for 5 steps you get an output y
# of length 8. If you only use y[:5], this does not mean
# that you only need to loop for 5 steps but actually
# only for 2 steps ( the first 3 are the initial state)
# array that contains the initial state of the output
# as well. Which means that if have a initial state of
# length 3, and you look for 5 steps you get an output
# y of length 8. If you only use y[:5], this does not
# mean that you only need to loop for 5 steps but
# actually only for 2 steps ( the first 3 are the
# initial state)
stop = stop - init_l[i]
# 2.3.3 we might get away with less number of steps
......@@ -510,10 +511,11 @@ class ScanSaveMem(gof.Optimizer):
elif (type(stop) is int and stop == sys.maxint):
global_nsteps = None
# yes if it is a int k, 0 < k < maxint
elif (type(stop) is int and global_nsteps['real'] < stop):
elif (type(stop) is int and
global_nsteps['real'] < stop):
global_nsteps['real'] = stop
# yes if it is a int k, 0 < k < maxint
elif (type(stop) is int and stop > 0 ):
elif (type(stop) is int and stop > 0):
pass
# not otherwise
else:
......@@ -526,10 +528,10 @@ class ScanSaveMem(gof.Optimizer):
# there are some symbolic tensors that limit the number of
# steps
if len(global_nsteps['sym']) == 0 :
if len(global_nsteps['sym']) == 0:
sym_steps = None
else:
sym_steps =global_nsteps['sym'][0]
sym_steps = global_nsteps['sym'][0]
for c in global_nsteps['sym'][1:]:
sym_steps = tensor.maximum(sym_steps, c)
......@@ -543,12 +545,11 @@ class ScanSaveMem(gof.Optimizer):
nw_steps = node.inputs[0]
global_nsteps = None
# 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store
for i,out in enumerate(node.outputs[:c_outs]):
for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients
for cl,_ in out.clients:
for cl, _ in out.clients:
if type(cl) == str:
store_steps[i] = 0
break
......@@ -562,7 +563,7 @@ class ScanSaveMem(gof.Optimizer):
store_steps[i] = 0
break
if ( isinstance(this_slice[0],slice) and
if (isinstance(this_slice[0], slice) and
this_slice[0].start is None):
store_steps[i] = 0
break
......@@ -575,46 +576,48 @@ class ScanSaveMem(gof.Optimizer):
except Exception:
length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice(
this_slice[0],length)
this_slice[0], length)
if isinstance(cf_slice[0], slice):
start = tensor.basic.extract_constant(cf_slice[0].start)
start = tensor.basic.extract_constant(
cf_slice[0].start)
else:
start = tensor.basic.extract_constant(cf_slice[0])
if start == 0 or store_steps[i] == 0:
store_steps[i] = 0
else:
pval = select_max(nw_steps -start + init_l[i], init_l[i])
pval = select_max(nw_steps - start + init_l[i],
init_l[i])
if store_steps[i] != -1:
pval = select_max(pval, store_steps[i])
store_steps[i] = pval
flag_store = True
orphane_outs = [ i for i,x in enumerate(store_steps)
if (type(x) is int) and (x<0) ]
flag_store = flag_store or (len(orphane_outs) > 0 )
orphane_outs = [i for i, x in enumerate(store_steps)
if (type(x) is int) and (x < 0)]
flag_store = flag_store or (len(orphane_outs) > 0)
# 3. is there anything to change ?
if (flag_store or global_nsteps is not None):
# 3.1 initialize inputs for the new scan
old_outputs = []
nw_inputs = list(node.inputs)
old_outputs = []
nw_inputs = list(node.inputs)
nw_inputs[0] = nw_steps
# 3.2 check orphane outputs to see if we can eliminate any
required,not_required = \
scan_utils.scan_can_remove_outs(node.op
, orphane_outs)
required, not_required = \
scan_utils.scan_can_remove_outs(node.op,
orphane_outs)
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required
# by the inner function .. )
replaced_outs = []
offset = 1 + op.n_seqs + op.n_mit_mot
for idx,_val in enumerate(store_steps[op.n_mit_mot:]):
for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
i = idx + op.n_mit_mot
if not( type(_val) is int and _val <=0 and i not in required):
if not(type(_val) is int and _val <= 0 and i not in required):
if idx+op.n_mit_mot in required:
if idx + op.n_mit_mot in required:
val = 1
else:
val = _val
......@@ -626,21 +629,21 @@ class ScanSaveMem(gof.Optimizer):
# a) the input is a set_subtensor, in that case we
# can replace the initial tensor by a slice,
# b) it is not, and we simply take a slice of it.
if (nw_inputs[offset+idx].owner and
isinstance(nw_inputs[offset+idx].owner.op,
if (nw_inputs[offset + idx].owner and
isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor)):
_nw_input = nw_inputs[offset+idx].owner.inputs[1]
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val - init_l[i]))
tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand( _nw_input,tmp )
nw_input = scan_utils.expand(_nw_input, tmp)
else:
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp]
nw_input = nw_inputs[offset + idx][:tmp]
nw_inputs[offset+idx] = nw_input
nw_inputs[offset + idx] = nw_input
replaced_outs.append(op.n_mit_mot + idx)
odx = op.n_mit_mot + idx
old_outputs += [(odx, [x[0].outputs[0] for x in
......@@ -648,8 +651,8 @@ class ScanSaveMem(gof.Optimizer):
# If there is no memory pre-allocated for this output
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
pos = ( op.n_mit_mot + idx + op.n_seqs
+ 1 + op.n_shared_outs )
pos = (op.n_mit_mot + idx + op.n_seqs +
1 + op.n_shared_outs)
if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val
odx = op.n_mit_mot + idx
......@@ -662,43 +665,41 @@ class ScanSaveMem(gof.Optimizer):
for idx, val in enumerate(store_steps[op.n_mit_mot:]):
if val == 0:
if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input = nw_inputs[offset+idx].owner.inputs[1]
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
odx = op.n_mit_mot + idx
nw_input = scan_utils.expand(_nw_input, nw_steps)
nw_inputs[offset+idx] = nw_input
nw_inputs[offset + idx] = nw_input
elif idx < (op.n_mit_sot + op.n_sit_sot +
+ op.n_nit_sot):
in_idx = offset+idx+op.n_shared_outs
op.n_nit_sot):
in_idx = offset + idx + op.n_shared_outs
if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] =nw_steps
nw_inputs[in_idx] = nw_steps
odx = op.n_mit_mot + idx
# 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {}
for k,v in compress_map.items():
for k, v in compress_map.items():
inv_compress_map[v] = k
node_ins = [ pre_greedy_local_optimizer(list_opt_slice, x) for x in
node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in
node_ins]
node_ins = pre_constant_merge(node_ins)
# 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
info['_scan_merge_visited'] = True
new_outs = scan_op.Scan(inps
, outs
, info).make_node(*node_ins).outputs
new_outs = scan_op.Scan(inps,
outs,
info).make_node(*node_ins).outputs
old_new = []
# 3.7 Get replace pairs for those outputs that do not change
# the number of intermediate steps stored
for idx,sl in enumerate(slices):
for idx, sl in enumerate(slices):
if global_nsteps and sl is not None and store_steps[idx] == 0:
for hdx,cl in enumerate(node.outputs[idx].clients):
for hdx, cl in enumerate(node.outputs[idx].clients):
cnf_slice, old_slices = sl[hdx]
# Sanitize the nw_slice by converting ints back into
# constants :) I only need to do this for the first
......@@ -713,18 +714,16 @@ class ScanSaveMem(gof.Optimizer):
else:
fslice = sanitize(cnf_slice[0])
nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = inv_compress_map[idx]
nw_out = new_outs[nw_pos]
subtens = tensor.basic.Subtensor(nw_slice)
# slice inputs
sl_ins = tensor.basic.Subtensor.collapse(
nw_slice
, lambda entry: isinstance(entry
, tensor.Variable))
nw_slice,
lambda entry: isinstance(entry,
tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0]
if new_o.ndim > 0:
......@@ -737,34 +736,35 @@ class ScanSaveMem(gof.Optimizer):
if len(old_outs) > 0:
nw_pos = compress_map[pos]
nw_out = new_outs[nw_pos]
for k,old in enumerate(old_outs):
for k, old in enumerate(old_outs):
# Get the correct slice
cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice:
start = ( cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos] )
if ( cnf_slice[0].stop is not None and
cnf_slice[0].stop != sys.maxint ):
stop = ( cnf_slice[0].stop - nw_steps -
start = (cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos])
if (cnf_slice[0].stop is not None and
cnf_slice[0].stop != sys.maxint):
stop = (cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos])
else:
stop = None
nw_slice = ( (slice(sanitize(start),
sanitize(stop),
sanitize(cnf_slice[0].step)),) +
tuple(old_slices[1:]) )
nw_slice = ((slice(sanitize(start),
sanitize(stop),
sanitize(cnf_slice[0].step)),)
+ tuple(old_slices[1:]))
else:
position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos] )
init_l[pos] + store_steps[pos])
nw_slice = (sanitize(position),) + tuple(old_slices[1:])
nw_slice = (sanitize(position),) + \
tuple(old_slices[1:])
subtens = tensor.basic.Subtensor(nw_slice)
sl_ins = tensor.basic.Subtensor.collapse(
nw_slice
, lambda entry: isinstance(entry
, tensor.Variable))
nw_slice,
lambda entry: isinstance(entry,
tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0]
if new_o.ndim > 0:
......@@ -773,13 +773,12 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None:
for idx,o in enumerate(node.outputs):
for idx, o in enumerate(node.outputs):
if not (idx in replaced_outs) and not idx in not_required:
nw_pos = compress_map[idx]
old_new += [(o,new_outs[nw_pos])]
env.replace_all_validate(old_new, reason = 'scan_save_mem')
old_new += [(o, new_outs[nw_pos])]
env.replace_all_validate(old_new, reason='scan_save_mem')
def apply(self, env):
......@@ -792,16 +791,16 @@ class ScanSaveMem(gof.Optimizer):
# Just before specialize to have the other optimization
# like constant folding being applied
# This don't introduce inplace.
scan_seqopt.register( 'scanOp_save_mem',
ScanSaveMem(),
4,
'fast_run',
'scan')
scan_seqopt.register('scanOp_save_mem',
ScanSaveMem(),
4,
'fast_run',
'scan')
class ScanMerge(gof.Optimizer):
""" Graph Optimizer that merges different scan ops """
def add_requirements(self,env):
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate())
def merge(self, nodes):
......@@ -812,29 +811,26 @@ class ScanMerge(gof.Optimizer):
else:
as_while = False
info = {}
info['tap_array'] = []
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
info = {}
info['tap_array'] = []
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
info['mit_mot_out_slices'] = []
info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes])
info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes])
info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes])
info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes])
info['truncate_gradient'] = nodes[0].op.truncate_gradient
info['name'] = '&'.join([nd.op.name for nd in nodes])
info['mode'] = nodes[0].op.mode
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = nodes[0].op.profile
inner_ins = []
outer_ins = []
info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes])
info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes])
info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes])
info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes])
info['truncate_gradient'] = nodes[0].op.truncate_gradient
info['name'] = '&'.join([nd.op.name for nd in nodes])
info['mode'] = nodes[0].op.mode
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = nodes[0].op.profile
inner_ins = []
outer_ins = []
inner_outs = []
outer_outs = []
......@@ -844,57 +840,56 @@ class ScanMerge(gof.Optimizer):
k.name += str(suffix)
return ls
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# Seq
inner_ins += rename(nd.op.inner_seqs(),idx)
outer_ins += rename(nd.op.outer_seqs(nd),idx)
inner_ins += rename(nd.op.inner_seqs(), idx)
outer_ins += rename(nd.op.outer_seqs(nd), idx)
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# MitMot
inner_ins += rename(nd.op.inner_mitmot(),idx)
inner_ins += rename(nd.op.inner_mitmot(), idx)
inner_outs += nd.op.inner_mitmot_outs()
info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd),idx)
outer_ins += rename(nd.op.outer_mitmot(nd), idx)
outer_outs += nd.op.outer_mitmot_outs(nd)
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# MitSot
inner_ins += rename(nd.op.inner_mitsot(),idx)
inner_ins += rename(nd.op.inner_mitsot(), idx)
inner_outs += nd.op.inner_mitsot_outs()
info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd),idx)
outer_ins += rename(nd.op.outer_mitsot(nd), idx)
outer_outs += nd.op.outer_mitsot_outs(nd)
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# SitSot
inner_ins += rename(nd.op.inner_sitsot(),idx)
inner_ins += rename(nd.op.inner_sitsot(), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs()
outer_ins += rename(nd.op.outer_sitsot(nd),idx)
outer_ins += rename(nd.op.outer_sitsot(nd), idx)
outer_outs += nd.op.outer_sitsot_outs(nd)
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# Shared
inner_ins += rename(nd.op.inner_shared(),idx)
outer_ins += rename(nd.op.outer_shared(nd),idx)
inner_ins += rename(nd.op.inner_shared(), idx)
outer_ins += rename(nd.op.outer_shared(nd), idx)
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# NitSot
inner_outs += nd.op.inner_nitsot_outs()
outer_ins += rename(nd.op.outer_nitsot(nd),idx)
outer_ins += rename(nd.op.outer_nitsot(nd), idx)
outer_outs += nd.op.outer_nitsot_outs(nd)
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd)
inner_outs += nd.op.inner_shared_outs()
for idx,nd in enumerate(nodes):
for idx, nd in enumerate(nodes):
# Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(),idx)
outer_ins += rename(nd.op.outer_non_seqs(nd),idx)
inner_ins += rename(nd.op.inner_non_seqs(), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd), idx)
# Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins
......@@ -913,8 +908,6 @@ class ScanMerge(gof.Optimizer):
return zip(outer_outs, new_outs)
def belongs_to_set(self, node, set_nodes):
"""
This function checks if node `node` belongs to `set_nodes`, in the
......@@ -934,7 +927,6 @@ class ScanMerge(gof.Optimizer):
except TypeError:
pass
rep_nsteps = rep.inputs[0]
try:
rep_nsteps = int(get_constant_value(rep_nsteps))
......@@ -959,11 +951,9 @@ class ScanMerge(gof.Optimizer):
rep.op.inputs)
return same_cond and (nsteps == rep_nsteps) and can_add
def apply(self, env):
# Collect all scan nodes ordered according to toposort
scan_nodes = [ nd for nd in env.toposort()
scan_nodes = [nd for nd in env.toposort()
if isinstance(nd.op, scan_op.Scan)]
# All sets of possibly mergeable nodes
......@@ -971,7 +961,7 @@ class ScanMerge(gof.Optimizer):
for nd in scan_nodes:
belongs_to_set_idx = -1
for pos,subset in enumerate(all_sets):
for pos, subset in enumerate(all_sets):
if self.belongs_to_set(nd, subset):
assert belongs_to_set_idx == -1
belongs_to_set_idx = pos
......@@ -984,7 +974,7 @@ class ScanMerge(gof.Optimizer):
for subset in all_sets:
if len(subset) > 1:
proposal = self.merge(subset)
env.replace_all_validate(proposal, reason = 'scan_merge')
env.replace_all_validate(proposal, reason='scan_merge')
# after const merge but before stabilize so that we can have identity
......@@ -996,23 +986,27 @@ scan_seqopt.register('scanOp_merge',
'fast_run',
'scan')
def has_duplicates(l):
    """Return True when `l` contains at least one repeated element.

    Equality is decided by ``__eq__`` (through the intermediate ``set``),
    so every element of `l` must be hashable.
    """
    return len(l) != len(set(l))
def make_equiv(lo, li):
    """Pair up inner inputs whose corresponding outer inputs coincide.

    Walks `lo` (outer inputs) and `li` (inner inputs) in lockstep.  The
    first time an outer input is seen, its inner input is remembered as
    the canonical one; every later occurrence of the same outer input
    contributes its inner input to ``left`` and the repeated outer input
    to ``right``.

    :param lo: sequence of (hashable) outer inputs.
    :param li: sequence of inner inputs, same length as `lo`.
    :return: ``(left, right)`` — the duplicate inner inputs and the outer
        inputs they duplicate, in encounter order.

    NOTE(review): despite the historical docstring saying "dictionary",
    this returns two parallel lists, which is what the callers below
    consume.
    """
    seeno = {}
    left = []
    right = []
    for o, i in zip(lo, li):
        if o in seeno:
            # `o` was already seen: `i` is a redundant inner input.
            left.append(i)
            right.append(o)
        else:
            seeno[o] = i
    return left, right
@gof.local_optimizer([None])
def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan):
......@@ -1072,58 +1066,68 @@ def scan_merge_inouts(node):
na = a
# start again
left = []
left = []
right = []
if has_duplicates(na.outer_in_shared):
_left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared)
left += _left
left += _left
right += _right
if has_duplicates(na.outer_in_sit_sot):
_left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot)
left += _left
left += _left
right += _right
if has_duplicates(na.outer_in_mit_mot):
seen = {}
for omm, imm, _sl in zip(na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices):
for omm, imm, _sl in zip(na.outer_in_mit_mot,
na.inner_in_mit_mot, na.mit_mot_in_slices):
sl = tuple(_sl)
if (omm, sl) in seen:
simm = seen[(omm, sl)]
left += imm
left += imm
right += simm
else:
seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot):
seen = {}
for oms, ims, _sl in zip(na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices):
for oms, ims, _sl in zip(na.outer_in_mit_sot,
na.inner_in_mit_sot,
na.mit_sot_in_slices):
sl = tuple(_sl)
if (oms, sl) in seen:
sims = seen[(oms, sl)]
left += ims
left += ims
right += sims
else:
seen[(oms, sl)] = ims
def map_out(i, o, seen):
    """Return a previously-recorded outer output equivalent to (i, o).

    If some already-seen inner output ``si`` computes the same thing as
    `i` (modulo the ``left``/``right`` input equivalences captured from
    the enclosing scope), reuse its outer output ``so``; otherwise record
    the pair in `seen` (mutated in place) and keep `o`.

    NOTE(review): relies on ``left``, ``right`` and ``equal_computations``
    from the enclosing ``scan_merge_inouts`` scope — confirm when splicing
    back into the full file.
    """
    for si, so in seen:
        if equal_computations([i], [si], left, right):
            return so
    seen.append((i, o))
    return o
seen = []
na.outer_out_nit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_nit_sot, na.outer_out_nit_sot)]
na.outer_out_nit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot)]
seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_sit_sot, na.outer_out_sit_sot)]
na.outer_out_sit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_sit_sot,
na.outer_out_sit_sot)]
seen = []
na.outer_out_mit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_mit_sot, na.outer_out_mit_sot)]
na.outer_out_mit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_mit_sot,
na.outer_out_mit_sot)]
seen = []
new_outer_out_mit_mot = []
for imm, omm, osl in zip(na.inner_out_mit_mot, na.outer_out_mit_mot, na.mit_mot_out_slices):
for imm, omm, osl in zip(na.inner_out_mit_mot,
na.outer_out_mit_mot, na.mit_mot_out_slices):
for simm, somm, sosl in seen:
if osl == sosl and equal_computations(imm, simm, left, right):
new_outer_out_mit_mot.append(somm)
......@@ -1136,7 +1140,7 @@ def scan_merge_inouts(node):
return na.outer_outputs
# Register the in/out merging pass at position 3 so it runs after the
# scan-merge optimization registered above.
scan_seqopt.register('scanOp_merge_inouts',
                     opt.in2out(scan_merge_inouts, ignore_newtrees=True),
                     3,
                     'fast_run',
                     'scan')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论