Commit 34d8a4de, authored by Razvan Pascanu

fixing PEP8 in several places

Note that I only went through the files and fixed whatever the PEP8 script for vim reported as incorrect; I did not actually pay attention to the text itself, so there may still be inconsistencies (for example, strange variable names). Also, only scan_opt.py is done completely — scan_op.py still has a number of PEP8 inconsistencies. After the fix I ran only the tests in scan_module/tests.py, and they passed, so I assume I have not introduced any bugs with these PEP8 fixes.
Parent 1d7e3679
...@@ -5,10 +5,10 @@ See scan.py for details on scan ...@@ -5,10 +5,10 @@ See scan.py for details on scan
""" """
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu " __authors__ = ("Razvan Pascanu "
"Frederic Bastien " "Frederic Bastien "
"James Bergstra " "James Bergstra "
"Pascal Lamblin " ) "Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal" __copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>" __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
...@@ -39,15 +39,11 @@ _logger = logging.getLogger('theano.scan_module.scan_op') ...@@ -39,15 +39,11 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
class Scan(PureOp): class Scan(PureOp):
# def __init__(self,
# OLD DOCUMENTATION CAN BE FOUND NEAR REVISION 2581 inputs,
# outputs,
info,
def __init__( self typeConstructor=None,
, inputs
, outputs
, info
, typeConstructor = None
): ):
""" """
:param inputs: inputs of the inner function of scan :param inputs: inputs of the inner function of scan
...@@ -56,7 +52,7 @@ class Scan(PureOp): ...@@ -56,7 +52,7 @@ class Scan(PureOp):
the scan op. the scan op.
""" """
# adding properties into self # adding properties into self
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.__dict__.update(info) self.__dict__.update(info)
# I keep a version of info in self, to use in __eq__ and __hash__, # I keep a version of info in self, to use in __eq__ and __hash__,
...@@ -70,15 +66,16 @@ class Scan(PureOp): ...@@ -70,15 +66,16 @@ class Scan(PureOp):
jdx = 0 jdx = 0
if typeConstructor is None: if typeConstructor is None:
typeConstructor = lambda broadcastable, dtype: TensorType( typeConstructor = lambda broadcastable, dtype: TensorType(
broadcastable = broadcastable, dtype = dtype) broadcastable=broadcastable, dtype=dtype)
while idx < self.n_mit_mot_outs: while idx < self.n_mit_mot_outs:
# Not that for mit_mot there are several output slices per # Not that for mit_mot there are several output slices per
# output sequence # output sequence
o = outputs[idx] o = outputs[idx]
self.output_types.append( self.output_types.append(
typeConstructor( broadcastable = (False,) + o.type.broadcastable typeConstructor(
, dtype = o.type.dtype) broadcastable=(False,) + o.type.broadcastable,
dtype=o.type.dtype)
) )
idx += len(self.mit_mot_out_slices[jdx]) idx += len(self.mit_mot_out_slices[jdx])
jdx += 1 jdx += 1
...@@ -88,32 +85,32 @@ class Scan(PureOp): ...@@ -88,32 +85,32 @@ class Scan(PureOp):
for o in outputs[idx:end]: for o in outputs[idx:end]:
self.output_types.append( self.output_types.append(
typeConstructor( typeConstructor(
broadcastable = (False,) + o.type.broadcastable broadcastable=(False,) + o.type.broadcastable,
, dtype = o.type.dtype )) dtype=o.type.dtype))
# shared outputs + possibly the ending condition # shared outputs + possibly the ending condition
for o in outputs[end:]: for o in outputs[end:]:
self.output_types.append( o.type ) self.output_types.append(o.type)
if self.as_while: if self.as_while:
self.output_types = self.output_types[:-1] self.output_types = self.output_types[:-1]
self.destroy_map = {} self.destroy_map = {}
if hasattr(self,'inplace') and self.inplace: if hasattr(self, 'inplace') and self.inplace:
for idx in xrange(self.n_mit_mot + self.n_mit_sot + for idx in xrange(self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot ): self.n_sit_sot):
self.destroy_map[idx] = [idx + 1 + self.n_seqs] self.destroy_map[idx] = [idx + 1 + self.n_seqs]
mode_instance = compile.mode.get_mode(self.mode) mode_instance = compile.mode.get_mode(self.mode)
# if the default mode is used, and that mode is ProfileMode # if the default mode is used, and that mode is ProfileMode
# then we need to copy the mode otherwise the time for a given # then we need to copy the mode otherwise the time for a given
# op will be counted multiple times # op will be counted multiple times
if ( self.mode is None and if (self.mode is None and
isinstance(mode_instance, compile.profilemode.ProfileMode) ): isinstance(mode_instance, compile.profilemode.ProfileMode)):
mode_instance = compile.profilemode.ProfileMode( mode_instance = compile.profilemode.ProfileMode(
optimizer = mode_instance.provided_optimizer optimizer=mode_instance.provided_optimizer,
, linker = mode_instance.provided_linker ) linker=mode_instance.provided_linker)
compile.profilemode.prof_mode_instance_to_print.append(mode_instance) compile.profilemode.prof_mode_instance_to_print.append(
mode_instance)
self.mode_instance = mode_instance self.mode_instance = mode_instance
if self.name: if self.name:
self.mode_instance.message = self.name + " sub profile" self.mode_instance.message = self.name + " sub profile"
...@@ -122,7 +119,7 @@ class Scan(PureOp): ...@@ -122,7 +119,7 @@ class Scan(PureOp):
else: else:
self.mode_instance = mode_instance self.mode_instance = mode_instance
if not hasattr(self,'name') or self.name is None: if not hasattr(self, 'name') or self.name is None:
self.name = 'scan_fn' self.name = 'scan_fn'
# to have a fair __eq__ comparison later on, we update the info with # to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the # the actual mode used to compile the function and the name of the
...@@ -130,27 +127,26 @@ class Scan(PureOp): ...@@ -130,27 +127,26 @@ class Scan(PureOp):
self.info['name'] = self.name self.info['name'] = self.name
# Pre-computing some values to speed up perform # Pre-computing some values to speed up perform
self.mintaps = [ numpy.min(x) for x in self.tap_array] self.mintaps = [numpy.min(x) for x in self.tap_array]
self.mintaps += [ 0 for x in xrange(self.n_nit_sot) ] self.mintaps += [0 for x in xrange(self.n_nit_sot)]
self.seqs_arg_offset = 1+self.n_seqs self.seqs_arg_offset = 1 + self.n_seqs
self.shared_arg_offset = ( self.seqs_arg_offset self.shared_arg_offset = (self.seqs_arg_offset +
+ self.n_mit_mot self.n_mit_mot +
+ self.n_mit_sot self.n_mit_sot +
+ self.n_sit_sot ) self.n_sit_sot)
self.nit_sot_arg_offset = ( self.shared_arg_offset + self.nit_sot_arg_offset = (self.shared_arg_offset +
self.n_shared_outs ) self.n_shared_outs)
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
if not self.info['gpu']: if not self.info['gpu']:
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs, tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
local_env = gof.Env(tmp_in, tmp_out) local_env = gof.Env(tmp_in, tmp_out)
self._cmodule_key = gof.CLinker.cmodule_key_(local_env,[]) self._cmodule_key = gof.CLinker.cmodule_key_(local_env, [])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
else: else:
self._hash_inner_graph = self.info['gpu_hash'] self._hash_inner_graph = self.info['gpu_hash']
def make_node(self, *inputs): def make_node(self, *inputs):
assert numpy.all(isinstance(i, gof.Variable) for i in inputs) assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
# assert dtype is consistent # assert dtype is consistent
...@@ -173,23 +169,23 @@ class Scan(PureOp): ...@@ -173,23 +169,23 @@ class Scan(PureOp):
# Flags that indicate which inputs are vectors # Flags that indicate which inputs are vectors
self.vector_seqs = [ seq.ndim == 1 for seq in self.vector_seqs = [seq.ndim == 1 for seq in
inputs[1:1+self.n_seqs ] ] inputs[1:1 + self.n_seqs]]
self.vector_outs = [ arg.ndim ==1 for arg in self.vector_outs = [arg.ndim == 1 for arg in
inputs[1+self.n_seqs: (1+self.n_seqs + inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)] ] self.n_outs)]]
self.vector_outs += [ False]*self.n_nit_sot self.vector_outs += [False] * self.n_nit_sot
# Check if input sequences and variables representing a slice of # Check if input sequences and variables representing a slice of
# them have the same dtype # them have the same dtype
for idx in xrange(self.n_seqs): for idx in xrange(self.n_seqs):
if inputs[1+idx].dtype != self.inputs[idx].dtype: if inputs[1 + idx].dtype != self.inputs[idx].dtype:
raise ValueError(err_msg1%( 'sequence' raise ValueError(err_msg1 % ('sequence',
, str(inputs[1+idx]) str(inputs[1 + idx]),
, idx idx,
, inputs[1+idx].dtype inputs[1 + idx].dtype,
, str(self.inputs[idx]) str(self.inputs[idx]),
, self.inputs[idx].dtype) ) self.inputs[idx].dtype))
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output # - initial state of the output
...@@ -198,73 +194,73 @@ class Scan(PureOp): ...@@ -198,73 +194,73 @@ class Scan(PureOp):
# Maybe checking that ndim fits would be good as well !? # Maybe checking that ndim fits would be good as well !?
index_i = self.n_seqs index_i = self.n_seqs
index_o = 0 index_o = 0
index = 1+self.n_seqs index = 1 + self.n_seqs
start = index start = index
end = index + self.n_mit_mot end = index + self.n_mit_mot
while index < end: while index < end:
for k in self.tap_array[index-start]: for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype: if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ' ' in scan nomenclature) ',
, str(inputs[index]) str(inputs[index]),
, index index,
, inputs[index].dtype inputs[index].dtype,
, str(self.inputs[index_i]) str(self.inputs[index_i]),
, self.inputs[index_i].dtype) ) self.inputs[index_i].dtype))
index_i += 1 index_i += 1
for k in self.mit_mot_out_slices[index-start]: for k in self.mit_mot_out_slices[index - start]:
if inputs[index].dtype != self.outputs[index_o].dtype: if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2%( str(inputs[index]) raise ValueError(err_msg2 % (str(inputs[index]),
, index index,
, inputs[index].dtype inputs[index].dtype,
, self.outputs[index_o].dtype) ) self.outputs[index_o].dtype))
index_o += 1 index_o += 1
index += 1 index += 1
# Same checks as above but for outputs of type mit_sot and sit_sot # Same checks as above but for outputs of type mit_sot and sit_sot
end += self.n_mit_sot + self.n_sit_sot end += self.n_mit_sot + self.n_sit_sot
while index < end: while index < end:
for k in self.tap_array[index-start]: for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype: if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'Initial state' raise ValueError(err_msg1 % ('Initial state',
, str(inputs[index]) str(inputs[index]),
, index index,
, inputs[index].dtype inputs[index].dtype,
, str(self.inputs[index_i]) str(self.inputs[index_i]),
, self.inputs[index_i].dtype) ) self.inputs[index_i].dtype))
index_i += 1 index_i += 1
if inputs[index].dtype != self.outputs[index_o].dtype: if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2%( str(inputs[index]) raise ValueError(err_msg2 % (str(inputs[index]),
, index index,
, inputs[index].dtype inputs[index].dtype,
, self.outputs[index_o].dtype) ) self.outputs[index_o].dtype))
index_o += 1 index_o += 1
index += 1 index += 1
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?! # dtype. Maybe even same type ?!
end += self.n_shared_outs end += self.n_shared_outs
index_o += self.n_nit_sot index_o += self.n_nit_sot
while index < end: while index < end:
if (hasattr(inputs[index],'dtype') and if (hasattr(inputs[index], 'dtype') and
inputs[index].dtype != self.outputs[index_o].dtype): inputs[index].dtype != self.outputs[index_o].dtype):
raise ValueError(err_msg2%( str(inputs[index]) raise ValueError(err_msg2 % (str(inputs[index]),
, index index,
, inputs[index].dtype inputs[index].dtype,
, self.outputs[index_o].dtype) ) self.outputs[index_o].dtype))
index += 1 index += 1
index_o += 1 index_o += 1
for x in inputs[index:index+ self.n_nit_sot]: for x in inputs[index:index + self.n_nit_sot]:
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization # used by truncated BPTT and by scan space optimization
if (str(x.dtype)[:3] not in ('uin','int') or if (str(x.dtype)[:3] not in ('uin', 'int') or
x.ndim != 0): x.ndim != 0):
raise ValueError('For output %d you need to provide a ' raise ValueError('For output %d you need to provide a '
'scalar int !',x) 'scalar int !', x)
apply_node = Apply(self apply_node = Apply(self,
, inputs inputs,
, [t() for t in self.output_types]) [t() for t in self.output_types])
return apply_node return apply_node
def __eq__(self, other): def __eq__(self, other):
...@@ -284,7 +280,7 @@ class Scan(PureOp): ...@@ -284,7 +280,7 @@ class Scan(PureOp):
# check. Namely, do the internal graph represent same # check. Namely, do the internal graph represent same
# computations # computations
for self_in, other_in in zip(self.inputs, other.inputs): for self_in, other_in in zip(self.inputs, other.inputs):
if self_in.type != other_in.type : if self_in.type != other_in.type:
return False return False
if not scan_utils.equal_computations(self.outputs, if not scan_utils.equal_computations(self.outputs,
...@@ -308,21 +304,19 @@ class Scan(PureOp): ...@@ -308,21 +304,19 @@ class Scan(PureOp):
else: else:
name = 'for' name = 'for'
if self.inplace : if self.inplace:
aux_txt = '%s{inplace,%s,%s}'%(name, gpu_str, str(self.name)) aux_txt = '%s{inplace,%s,%s}' % (name, gpu_str, str(self.name))
else: else:
aux_txt = '%s{%s,%s}'%(name,gpu_str, str(self.name)) aux_txt = '%s{%s,%s}' % (name, gpu_str, str(self.name))
return aux_txt return aux_txt
def __hash__(self): def __hash__(self):
return ( hash(type(self)) ^ return (hash(type(self)) ^
# and a hash representing the inner graph using the # and a hash representing the inner graph using the
# CLinker.cmodule_key_ # CLinker.cmodule_key_
self._hash_inner_graph ^ self._hash_inner_graph ^
scan_utils.hash_listsDictsTuples(self.info) ) scan_utils.hash_listsDictsTuples(self.info))
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
""" """
...@@ -348,7 +342,6 @@ class Scan(PureOp): ...@@ -348,7 +342,6 @@ class Scan(PureOp):
# Setting up all my variables in what I believe is a more Cython # Setting up all my variables in what I believe is a more Cython
# friendly form # friendly form
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs] node_input_compute = [compute_map[r] for r in node.inputs]
...@@ -357,64 +350,65 @@ class Scan(PureOp): ...@@ -357,64 +350,65 @@ class Scan(PureOp):
# If a shared variable is the result of a ViewOp it is a clear # If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of # indication that we need to copy that value after the perform of
# scan is done # scan is done
slices = ( self.n_mit_mot_outs + slices = (self.n_mit_mot_outs +
self.n_mit_sot + self.n_mit_sot +
self.n_sit_sot + self.n_sit_sot +
self.n_nit_sot ) self.n_nit_sot)
wrapped_inputs = [Param(x, borrow=True) for x in self.inputs ] wrapped_inputs = [Param(x, borrow=True) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=True) for x in wrapped_outputs = [Out(x, borrow=True) for x in
self.outputs[:slices] ] self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:] wrapped_outputs += self.outputs[slices:]
profile = None profile = None
if (theano.config.profile or (isinstance(self.profile, (basestring, bool, int)) if (theano.config.profile or
(isinstance(self.profile, (basestring, bool, int))
and self.profile)): and self.profile)):
if isinstance(self.profile, basestring): if isinstance(self.profile, basestring):
profile = ScanProfileStats(name = self.profile) profile = ScanProfileStats(name=self.profile)
else: else:
profile = ScanProfileStats(name = self.name) profile = ScanProfileStats(name=self.name)
elif self.profile: elif self.profile:
profile = self.profile profile = self.profile
self.fn = function(wrapped_inputs, self.fn = function(wrapped_inputs,
wrapped_outputs, wrapped_outputs,
mode = self.mode_instance, mode=self.mode_instance,
name = self.name, name=self.name,
profile = profile) profile=profile)
try: try:
cython_mintaps = numpy.asarray(self.mintaps, dtype = 'int32') raise ImportError
cython_mintaps = numpy.asarray(self.mintaps, dtype='int32')
cython_tap_array_len = \ cython_tap_array_len = \
numpy.asarray([ len(x) for x in self.tap_array], numpy.asarray([len(x) for x in self.tap_array],
dtype='int32') dtype='int32')
if len(self.tap_array) == 0: if len(self.tap_array) == 0:
d1 = 0 d1 = 0
else: else:
d1 = numpy.max(cython_tap_array_len) d1 = numpy.max(cython_tap_array_len)
d0 = len(self.tap_array) d0 = len(self.tap_array)
cython_tap_array = numpy.zeros((d0,d1), dtype='int32') cython_tap_array = numpy.zeros((d0, d1), dtype='int32')
for _d0 in range(d0): for _d0 in range(d0):
for _d1 in range(cython_tap_array_len[_d0]): for _d1 in range(cython_tap_array_len[_d0]):
cython_tap_array[_d0,_d1] = self.tap_array[_d0][_d1] cython_tap_array[_d0, _d1] = self.tap_array[_d0][_d1]
cython_mit_mot_out_nslices = \ cython_mit_mot_out_nslices = \
numpy.asarray([ len(x) for x in self.mit_mot_out_slices], numpy.asarray([len(x) for x in self.mit_mot_out_slices],
dtype='int32') dtype='int32')
if len(self.mit_mot_out_slices) == 0: if len(self.mit_mot_out_slices) == 0:
d1 = 0 d1 = 0
else: else:
d1 = numpy.max(cython_mit_mot_out_nslices) d1 = numpy.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices) d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = numpy.zeros((d0,d1), cython_mit_mot_out_slices = numpy.zeros((d0, d1),
dtype='int32') dtype='int32')
for _d0 in range(d0): for _d0 in range(d0):
for _d1 in range(cython_mit_mot_out_nslices[_d0]): for _d1 in range(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0,_d1] = \ cython_mit_mot_out_slices[_d0, _d1] = \
self.mit_mot_out_slices[_d0][_d1] self.mit_mot_out_slices[_d0][_d1]
vector_seqs = [ seq.ndim == 1 for seq in vector_seqs = [seq.ndim == 1 for seq in
node.inputs[1:1+self.n_seqs ] ] node.inputs[1:1 + self.n_seqs]]
vector_outs = [ arg.ndim ==1 for arg in vector_outs = [arg.ndim == 1 for arg in
node.inputs[1+self.n_seqs: (1+self.n_seqs + node.inputs[1 + self.n_seqs:
self.n_outs)] ] (1 + self.n_seqs + self.n_outs)]]
vector_outs += [ False]*self.n_nit_sot vector_outs += [False] * self.n_nit_sot
cython_vector_seqs = numpy.asarray(self.vector_seqs, cython_vector_seqs = numpy.asarray(self.vector_seqs,
dtype='int32') dtype='int32')
...@@ -448,6 +442,7 @@ class Scan(PureOp): ...@@ -448,6 +442,7 @@ class Scan(PureOp):
except ImportError: except ImportError:
p = self.execute p = self.execute
# default arguments are stored in the closure of `rval` # default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o) r = p(n, [x[0] for x in i], o)
for o in node.outputs: for o in node.outputs:
...@@ -463,14 +458,14 @@ class Scan(PureOp): ...@@ -463,14 +458,14 @@ class Scan(PureOp):
return self.inputs[:self.n_seqs] return self.inputs[:self.n_seqs]
def outer_seqs(self, node): def outer_seqs(self, node):
return node.inputs[1:1+self.n_seqs] return node.inputs[1:1 + self.n_seqs]
def inner_mitmot(self): def inner_mitmot(self):
n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot]) n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
return self.inputs[self.n_seqs: self.n_seqs + n_taps] return self.inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, node): def outer_mitmot(self, node):
return node.inputs[1+self.n_seqs:1+self.n_seqs+self.n_mit_mot] return node.inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self): def inner_mitmot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
...@@ -490,80 +485,80 @@ class Scan(PureOp): ...@@ -490,80 +485,80 @@ class Scan(PureOp):
ntaps_upto_sit_sot = sum(len(x) for x in ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
return self.inputs[self.n_seqs+n_mitmot_taps: return self.inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs+ntaps_upto_sit_sot] self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, node): def outer_mitsot(self, node):
offset = 1+self.n_seqs+self.n_mit_mot offset = 1 + self.n_seqs + self.n_mit_mot
return node.inputs[offset:offset+self.n_mit_sot] return node.inputs[offset:offset + self.n_mit_sot]
def inner_mitsot_outs(self): def inner_mitsot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[n_taps:n_taps+self.n_mit_sot] return self.outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, node): def outer_mitsot_outs(self, node):
return node.outputs[self.n_mit_mot:self.n_mit_mot+self.n_mit_sot] return node.outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self): def mitsot_taps(self):
return self.tap_array[self.n_mit_mot:self.n_mit_mot+self.n_mit_sot] return self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
def inner_sitsot(self): def inner_sitsot(self):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot offset = self.n_seqs + n_taps_upto_sit_sot
return self.inputs[offset:offset+self.n_sit_sot] return self.inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self,node): def outer_sitsot(self, node):
offset = 1+self.n_seqs+self.n_mit_mot + self.n_mit_sot offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return node.inputs[offset:offset+self.n_sit_sot] return node.inputs[offset:offset + self.n_sit_sot]
def inner_sitsot_outs(self): def inner_sitsot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps offset = self.n_mit_sot + n_taps
return self.outputs[offset:offset+self.n_sit_sot] return self.outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, node): def outer_sitsot_outs(self, node):
offset = self.n_mit_mot + self.n_mit_sot offset = self.n_mit_mot + self.n_mit_sot
return node.outputs[offset:offset+self.n_sit_sot] return node.outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, node): def outer_nitsot(self, node):
offset = (1 + self.n_seqs+self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs) self.n_sit_sot + self.n_shared_outs)
return node.inputs[offset:offset+self.n_nit_sot] return node.inputs[offset:offset + self.n_nit_sot]
def inner_nitsot_outs(self): def inner_nitsot_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot offset = self.n_mit_sot + n_taps + self.n_sit_sot
return self.outputs[offset:offset+self.n_nit_sot] return self.outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, node): def outer_nitsot_outs(self, node):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot) offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return node.outputs[offset:offset+self.n_nit_sot] return node.outputs[offset:offset + self.n_nit_sot]
def inner_shared(self): def inner_shared(self):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot
return self.inputs[offset:offset+self.n_shared_outs] return self.inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, node): def outer_shared(self, node):
offset = (1 + self.n_seqs+self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot) self.n_sit_sot)
return node.inputs[offset:offset+self.n_shared_outs] return node.inputs[offset:offset + self.n_shared_outs]
def inner_shared_outs(self): def inner_shared_outs(self):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot
return self.outputs[offset:offset+self.n_shared_outs] return self.outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, node): def outer_shared_outs(self, node):
offset = ( self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
return node.outputs[offset:offset+self.n_shared_outs] return node.outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self): def inner_non_seqs(self):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
...@@ -574,12 +569,11 @@ class Scan(PureOp): ...@@ -574,12 +569,11 @@ class Scan(PureOp):
return self.inputs[offset:] return self.inputs[offset:]
def outer_non_seqs(self, node): def outer_non_seqs(self, node):
offset = ( 1+ self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs) self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return node.inputs[offset:] return node.inputs[offset:]
def execute(self, node, args, outs):
def execute( self, node, args, outs):
""" """
The args are packed like this: The args are packed like this:
...@@ -607,7 +601,7 @@ class Scan(PureOp): ...@@ -607,7 +601,7 @@ class Scan(PureOp):
# negative flip sequences around, and make n_steps positive # negative flip sequences around, and make n_steps positive
t0_call = time.time() t0_call = time.time()
t_fn = 0 t_fn = 0
n_steps = args[0] n_steps = args[0]
seqs = [] seqs = []
if n_steps < 0: if n_steps < 0:
n_steps = abs(n_steps) n_steps = abs(n_steps)
...@@ -616,7 +610,7 @@ class Scan(PureOp): ...@@ -616,7 +610,7 @@ class Scan(PureOp):
raise ValueError(('Sequence is shorter then the required ' raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, ' 'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps, 'seq.shape):'), n_steps,
node.inputs[1+idx], node.inputs[1 + idx],
seq.shape) seq.shape)
seqs.append(seq[::-1]) seqs.append(seq[::-1])
else: else:
...@@ -625,35 +619,37 @@ class Scan(PureOp): ...@@ -625,35 +619,37 @@ class Scan(PureOp):
raise ValueError(('Sequence is shorter then the required ' raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, ' 'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps, 'seq.shape):'), n_steps,
node.inputs[1+idx], node.inputs[1 + idx],
seq.shape) seq.shape)
seqs.append(seq) seqs.append(seq)
# 2. Allocate memory for the outputs. Construct the list: # 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containting the length of each output # store_steps -- map containting the length of each output
# pos -- map containing the current position of each output # pos -- map containing the current position of each
# output
store_steps = [ arg.shape[0] for arg store_steps = [arg.shape[0] for arg
in args[self.seqs_arg_offset: in args[self.seqs_arg_offset:
self.shared_arg_offset] ] self.shared_arg_offset]]
store_steps += [ arg for arg in store_steps += [arg for arg in
args[self.nit_sot_arg_offset: args[self.nit_sot_arg_offset:
self.nit_sot_arg_offset+self.n_nit_sot] self.nit_sot_arg_offset + self.n_nit_sot]
] ]
pos = [ (-self.mintaps[idx])%store_steps[idx] for idx pos = [(-self.mintaps[idx]) % store_steps[idx] for idx
in xrange(self.n_outs+self.n_nit_sot)] in xrange(self.n_outs + self.n_nit_sot)]
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
for idx in xrange(self.n_outs): for idx in xrange(self.n_outs):
if self.inplace: if self.inplace:
# ^ Case 1. Outputs should be computed inplace of their # ^ Case 1. Outputs should be computed inplace of their
# initial state # initial state
outs[idx][0] = args[self.seqs_arg_offset + idx ] outs[idx][0] = args[self.seqs_arg_offset + idx]
elif ( outs[idx][0] is not None and elif (outs[idx][0] is not None and
outs[idx][0].shape[1:] == args[self.seqs_arg_offset + idx].shape[1:] outs[idx][0].shape[1:] == args[self.seqs_arg_offset +
and outs[idx][0].shape[0] >= store_steps[idx] ): idx].shape[1:]
and outs[idx][0].shape[0] >= store_steps[idx]):
# Put in the values of the initial state # Put in the values of the initial state
outs[idx][0] = outs[idx][0][:store_steps[idx]] outs[idx][0] = outs[idx][0][:store_steps[idx]]
if idx > self.n_mit_mot: if idx > self.n_mit_mot:
l = - self.mintaps[idx] l = - self.mintaps[idx]
outs[idx][0][:l] = args[self.seqs_arg_offset + idx][:l] outs[idx][0][:l] = args[self.seqs_arg_offset + idx][:l]
...@@ -662,28 +658,28 @@ class Scan(PureOp): ...@@ -662,28 +658,28 @@ class Scan(PureOp):
else: else:
outs[idx][0] = args[self.seqs_arg_offset + idx].copy() outs[idx][0] = args[self.seqs_arg_offset + idx].copy()
offset = self.nit_sot_arg_offset + self.n_nit_sot offset = self.nit_sot_arg_offset + self.n_nit_sot
other_args = args[offset:] other_args = args[offset:]
input_storage = self.fn.input_storage input_storage = self.fn.input_storage
output_storage = self.fn.output_storage output_storage = self.fn.output_storage
fn = self.fn.fn fn = self.fn.fn
offset = ( self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) + offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
self.n_shared_outs) self.n_shared_outs)
for idx in xrange(len(other_args)): for idx in xrange(len(other_args)):
input_storage[idx+offset].storage[0] = other_args[idx] input_storage[idx + offset].storage[0] = other_args[idx]
i = 0 i = 0
cond = True cond = True
############## THE MAIN LOOP ######################### ############## THE MAIN LOOP #########################
#for i in xrange(n_steps): #for i in xrange(n_steps):
while (i< n_steps) and cond: while (i < n_steps) and cond:
# sequences over which scan iterates # sequences over which scan iterates
# 3. collect input slices # 3. collect input slices
for idx in xrange(self.n_seqs): for idx in xrange(self.n_seqs):
if self.vector_seqs[idx]: if self.vector_seqs[idx]:
input_storage[idx].storage[0] = seqs[idx][i:i+1].reshape(()) input_storage[idx].storage[0] = \
seqs[idx][i:i + 1].reshape(())
else: else:
input_storage[idx].storage[0] = seqs[idx][i] input_storage[idx].storage[0] = seqs[idx][i]
...@@ -691,26 +687,25 @@ class Scan(PureOp): ...@@ -691,26 +687,25 @@ class Scan(PureOp):
for idx in xrange(self.n_outs): for idx in xrange(self.n_outs):
if self.vector_outs[idx]: if self.vector_outs[idx]:
for tap in self.tap_array[idx]: for tap in self.tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx] + tap) % store_steps[idx]
input_storage[offset].storage[0] =\ input_storage[offset].storage[0] =\
outs[idx][0][_idx:_idx+1].reshape(()) outs[idx][0][_idx:_idx + 1].reshape(())
offset += 1 offset += 1
else: else:
for tap in self.tap_array[idx]: for tap in self.tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx] + tap) % store_steps[idx]
input_storage[offset].storage[0] = outs[idx][0][_idx] input_storage[offset].storage[0] = outs[idx][0][_idx]
offset += 1 offset += 1
a_offset = self.shared_arg_offset a_offset = self.shared_arg_offset
o_offset = self.n_outs + self.n_nit_sot o_offset = self.n_outs + self.n_nit_sot
if i == 0: if i == 0:
for j in xrange(self.n_shared_outs): for j in xrange(self.n_shared_outs):
input_storage[offset].storage[0] = args[a_offset+j] input_storage[offset].storage[0] = args[a_offset + j]
offset += 1 offset += 1
else: else:
for j in xrange(self.n_shared_outs): for j in xrange(self.n_shared_outs):
input_storage[offset].storage[0] = outs[o_offset+j][0] input_storage[offset].storage[0] = outs[o_offset + j][0]
offset += 1 offset += 1
# 4. collecting slices where the output should be stored # 4. collecting slices where the output should be stored
...@@ -718,23 +713,24 @@ class Scan(PureOp): ...@@ -718,23 +713,24 @@ class Scan(PureOp):
output_storage[idx].storage[0] = None output_storage[idx].storage[0] = None
offset = self.n_mit_mot_outs offset = self.n_mit_mot_outs
if i !=0 and self.n_nit_sot >0: if i != 0 and self.n_nit_sot > 0:
for idx in xrange(self.n_outs + self.n_nit_sot - for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot): self.n_mit_mot):
if ( store_steps[idx+self.n_mit_mot] == 1 or if (store_steps[idx + self.n_mit_mot] == 1 or
self.vector_outs[idx+self.n_mit_mot]): self.vector_outs[idx + self.n_mit_mot]):
output_storage[idx+offset].storage[0] = None output_storage[idx + offset].storage[0] = None
else: else:
output_storage[idx+offset].storage[0] =\ _pos0 = idx + self.n_mit_mot
outs[idx+self.n_mit_mot][0][pos[idx+self.n_mit_mot]] output_storage[idx + offset].storage[0] =\
outs[_pos0][0][pos[_pos0]]
else: else:
for idx in xrange(self.n_outs + self.n_nit_sot - for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot): self.n_mit_mot):
output_storage[idx+offset].storage[0] = None output_storage[idx + offset].storage[0] = None
offset += self.n_outs+self.n_nit_sot - self.n_mit_mot offset += self.n_outs + self.n_nit_sot - self.n_mit_mot
for idx in xrange(self.n_shared_outs): for idx in xrange(self.n_shared_outs):
output_storage[idx+offset].storage[0] = None output_storage[idx + offset].storage[0] = None
# If condition add it to the mix # If condition add it to the mix
if self.as_while: if self.as_while:
pdx = offset + self.n_shared_outs pdx = offset + self.n_shared_outs
...@@ -762,97 +758,102 @@ class Scan(PureOp): ...@@ -762,97 +758,102 @@ class Scan(PureOp):
# 5.1 Copy over the values for mit_mot outputs # 5.1 Copy over the values for mit_mot outputs
for j in xrange(self.n_mit_mot): for j in xrange(self.n_mit_mot):
for k in self.mit_mot_out_slices[j]: for k in self.mit_mot_out_slices[j]:
outs[j][0][k+pos[j]] = output_storage[offset_out].storage[0] outs[j][0][k + pos[j]] = \
output_storage[offset_out].storage[0]
offset_out += 1 offset_out += 1
# 5.2 Copy over the values for mit_sot/sit_sot outputs # 5.2 Copy over the values for mit_sot/sit_sot outputs
begin = self.n_mit_mot begin = self.n_mit_mot
end = self.n_outs end = self.n_outs
offset_out -= self.n_mit_mot offset_out -= self.n_mit_mot
for j in xrange(begin, end): for j in xrange(begin, end):
if ( store_steps[j] == 1 or self.vector_outs[j] or if (store_steps[j] == 1 or self.vector_outs[j] or
outs[j][0][pos[j]] is not output_storage[offset_out+j].storage[0]): outs[j][0][pos[j]] is not
output_storage[offset_out + j].storage[0]):
outs[j][0][pos[j]] = output_storage[offset_out+j].storage[0] outs[j][0][pos[j]] = \
output_storage[offset_out + j].storage[0]
# 5.3 Copy over the values for nit_sot outputs # 5.3 Copy over the values for nit_sot outputs
begin = end begin = end
end += self.n_nit_sot end += self.n_nit_sot
for j in xrange(begin,end): for j in xrange(begin, end):
if i == 0: if i == 0:
jout = j+offset_out jout = j + offset_out
shape = (store_steps[j],) + output_storage[jout].storage[0].shape shape = (store_steps[j],) + \
output_storage[jout].storage[0].shape
if len(output_storage[jout].storage[0].shape) == 0: if len(output_storage[jout].storage[0].shape) == 0:
self.vector_outs[j] = True self.vector_outs[j] = True
dtype = output_storage[jout].storage[0].dtype dtype = output_storage[jout].storage[0].dtype
if (outs[j][0] is None or if (outs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype ): outs[j][0].dtype != dtype):
if self.gpu: if self.gpu:
outs[j][0] = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape) _cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
outs[j][0] = _cuda.zeros(shape)
else: else:
outs[j][0] = numpy.zeros(shape, dtype) outs[j][0] = numpy.zeros(shape, dtype)
elif outs[j][0].shape[0] != store_steps[j]: elif outs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]] outs[j][0] = outs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = output_storage[jout].storage[0] outs[j][0][pos[j]] = output_storage[jout].storage[0]
elif (store_steps[j] == 1 or self.vector_outs[j] or elif (store_steps[j] == 1 or self.vector_outs[j] or
outs[j][0][pos[j]] is not output_storage[j+offset_out].storage[0]): outs[j][0][pos[j]] is not
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0] output_storage[j + offset_out].storage[0]):
outs[j][0][pos[j]] = \
output_storage[j + offset_out].storage[0]
# 5.4 Copy over the values for outputs corresponding to shared # 5.4 Copy over the values for outputs corresponding to shared
# variables # variables
begin = end begin = end
end += self.n_shared_outs end += self.n_shared_outs
for j in xrange(begin,end): for j in xrange(begin, end):
jout = j +offset_out jout = j + offset_out
outs[j][0] = output_storage[jout].storage[0] outs[j][0] = output_storage[jout].storage[0]
pos = [ (idx+1)%store for idx,store in pos = [(idx + 1) % store for idx, store in
itertools.izip(pos, store_steps) itertools.izip(pos, store_steps)]
] i = i + 1
i = i+1
# 6. Check if you need to re-order output buffers # 6. Check if you need to re-order output buffers
begin = self.n_mit_mot begin = self.n_mit_mot
end = self.n_outs + self.n_nit_sot end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end): for idx in xrange(begin, end):
min_tap = self.mintaps[idx] min_tap = self.mintaps[idx]
if ( store_steps[idx] < i-self.mintaps[idx] and if (store_steps[idx] < i - self.mintaps[idx] and
pos[idx] < store_steps[idx] ): pos[idx] < store_steps[idx]):
pdx = pos[idx] pdx = pos[idx]
if pdx < store_steps[idx]//2 : if pdx < store_steps[idx] // 2:
shape = (pdx,)+ outs[idx][0].shape[1:] shape = (pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance( outs[idx][0], if cuda.cuda_available and isinstance(outs[idx][0],
cuda.CudaNdarray): cuda.CudaNdarray):
tmp = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape) _cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
tmp = _cuda.zeros(shape)
else: else:
tmp = numpy.empty(shape) tmp = numpy.empty(shape)
tmp[:] = outs[idx][0][:pdx] tmp[:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = outs[idx][0][pdx:] outs[idx][0][:store_steps[idx] - pdx] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = tmp outs[idx][0][store_steps[idx] - pdx:] = tmp
del tmp del tmp
else: else:
shape = (store_steps[idx]-pdx,) + outs[idx][0].shape[1:] shape = (store_steps[idx] - pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance( outs[idx][0], if cuda.cuda_available and isinstance(outs[idx][0],
cuda.CudaNdarray): cuda.CudaNdarray):
tmp = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape) _cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
tmp = _cuda.zeros(shape)
else: else:
tmp = numpy.empty(shape) tmp = numpy.empty(shape)
tmp[:] = outs[idx][0][pdx:] tmp[:] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = outs[idx][0][:pdx] outs[idx][0][store_steps[idx] - pdx:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = tmp outs[idx][0][:store_steps[idx] - pdx] = tmp
del tmp del tmp
# This would normally happen only when doing truncated # This would normally happen only when doing truncated
# backpropagation through time. In such a scenarion Scan is # backpropagation through time. In such a scenarion Scan is
# expected to return 0 for all entries for which the gradient is # expected to return 0 for all entries for which the gradient is
# not actually computed # not actually computed
elif store_steps[idx] > i - self.mintaps[idx]: elif store_steps[idx] > i - self.mintaps[idx]:
outs[idx][0][i-self.mintaps[idx]:] = 0 outs[idx][0][i - self.mintaps[idx]:] = 0
# This is a fix for a bug introduced by while. If you say # This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output # you want to loop up to a condition, you expect the output
# to have that length ( and not the maximal length possible) # to have that length ( and not the maximal length possible)
...@@ -883,7 +884,7 @@ class Scan(PureOp): ...@@ -883,7 +884,7 @@ class Scan(PureOp):
profile.callcount += 1 profile.callcount += 1
profile.nbsteps += n_steps profile.nbsteps += n_steps
profile.call_time += t_call profile.call_time += t_call
profile.vm_call_time += t_fn profile.vm_call_time += t_fn
if hasattr(self.fn.fn, 'update_profile'): if hasattr(self.fn.fn, 'update_profile'):
self.fn.fn.update_profile(profile) self.fn.fn.update_profile(profile)
...@@ -896,7 +897,7 @@ class Scan(PureOp): ...@@ -896,7 +897,7 @@ class Scan(PureOp):
#self.fn.maker.mode.fn_time += t_fn #self.fn.maker.mode.fn_time += t_fn
# Old Profile Mode */ # Old Profile Mode */
self.t_call = t_call self.t_call = t_call
self.t_fn = t_fn self.t_fn = t_fn
### Infer Shape ### Infer Shape
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
...@@ -905,26 +906,27 @@ class Scan(PureOp): ...@@ -905,26 +906,27 @@ class Scan(PureOp):
# is the shape of self.inputs[i] # is the shape of self.inputs[i]
# sequences # sequences
seqs_shape = [ x[1:] for x in input_shapes[1:1+self.n_seqs] ] seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]
# mit_mot, mit_sot, sit_sot # mit_mot, mit_sot, sit_sot
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outs_shape = [] outs_shape = []
for idx in xrange(n_outs): for idx in xrange(n_outs):
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
outs_shape += [ input_shapes[idx+self.n_seqs+1][1:] ] outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]
# shared_outs # shared_outs
offset = 1 + self.n_seqs + n_outs offset = 1 + self.n_seqs + n_outs
for idx in xrange(self.n_shared_outs): for idx in xrange(self.n_shared_outs):
outs_shape += [ input_shapes[idx+offset] ] outs_shape += [input_shapes[idx + offset]]
# non_sequences # non_sequences
offset += self.n_nit_sot + self.n_shared_outs offset += self.n_nit_sot + self.n_shared_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inputs) assert len(inner_ins_shapes) == len(self.inputs)
# Non-sequences have a direct equivalent from self.inputs in node.inputs # Non-sequences have a direct equivalent from self.inputs in
# node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):] inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
out_equivalent = {} out_equivalent = {}
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]): for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
...@@ -934,22 +936,22 @@ class Scan(PureOp): ...@@ -934,22 +936,22 @@ class Scan(PureOp):
else: else:
self_outs = self.outputs self_outs = self.outputs
outs_shape = scan_utils.infer_shape( outs_shape = scan_utils.infer_shape(
outs = self_outs, outs=self_outs,
inputs = self.inputs, inputs=self.inputs,
input_shapes = inner_ins_shapes) input_shapes=inner_ins_shapes)
# Will be used to check if outs_shape can be expressed without using # Will be used to check if outs_shape can be expressed without using
# variables in self.inputs. # variables in self.inputs.
# The shapes of node.inputs are valid. # The shapes of node.inputs are valid.
validator = scan_utils.Validator( validator = scan_utils.Validator(
valid = input_shapes, valid=input_shapes,
invalid = self.inputs, invalid=self.inputs,
valid_equivalent = out_equivalent) valid_equivalent=out_equivalent)
offset = 1 + self.n_seqs offset = 1 + self.n_seqs
scan_outs = [x for x in input_shapes[offset:offset+n_outs]] scan_outs = [x for x in input_shapes[offset:offset + n_outs]]
offset += n_outs offset += n_outs
for x in xrange(self.n_nit_sot): for x in xrange(self.n_nit_sot):
out_shape_x = outs_shape[n_outs+x] out_shape_x = outs_shape[n_outs + x]
if out_shape_x is None: if out_shape_x is None:
# This output is not a tensor, and has no shape # This output is not a tensor, and has no shape
scan_outs.append(None) scan_outs.append(None)
...@@ -957,10 +959,10 @@ class Scan(PureOp): ...@@ -957,10 +959,10 @@ class Scan(PureOp):
# We need to make sure that we can compute the shapes from # We need to make sure that we can compute the shapes from
# node.inputs, and constants, without using the variables # node.inputs, and constants, without using the variables
# in the inner function. # in the inner function.
r = node.outputs[n_outs+x] r = node.outputs[n_outs + x]
assert r.ndim == 1 + len(out_shape_x) assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset+self.n_shared_outs+x]] shp = [node.inputs[offset + self.n_shared_outs + x]]
for i, shp_i in zip(xrange(1,r.ndim), out_shape_x): for i, shp_i in zip(xrange(1, r.ndim), out_shape_x):
# Validate shp_i. v_shape_i is either None (if invalid), # Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates # or a (variable, Boolean) tuple. The Boolean indicates
# whether variable is shp_i (if True), or an valid # whether variable is shp_i (if True), or an valid
...@@ -976,34 +978,32 @@ class Scan(PureOp): ...@@ -976,34 +978,32 @@ class Scan(PureOp):
shp.append(v_shp_i[0]) shp.append(v_shp_i[0])
scan_outs.append(tuple(shp)) scan_outs.append(tuple(shp))
scan_outs += [ x for x in scan_outs += [x for x in
input_shapes[offset:offset+self.n_shared_outs] ] input_shapes[offset:offset + self.n_shared_outs]]
return scan_outs return scan_outs
### GRAD FUNCTION ### GRAD FUNCTION
def grad(self, args, g_outs): def grad(self, args, g_outs):
# 1. forward pass - get the outputs after applying scan # 1. forward pass - get the outputs after applying scan
scan_outputs = self(*args) scan_outputs = self(*args)
# 2. make sure they are given as a list # 2. make sure they are given as a list
if not( type(scan_outputs) in (list,tuple)): if not(type(scan_outputs) in (list, tuple)):
scan_outputs = [scan_outputs] scan_outputs = [scan_outputs]
# 3. un-group / unzip the inputs # 3. un-group / unzip the inputs
# Note ! We don't want to use the actual same variable as the ones # Note ! We don't want to use the actual same variable as the ones
# used by the original scan, rather create clones of them # used by the original scan, rather create clones of them
rval = scan_utils.reconstruct_graph(self.inputs, rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs,'_grad') self.outputs, '_grad')
self_inputs = rval[0] self_inputs = rval[0]
self_outputs = rval[1] self_outputs = rval[1]
seqs = self_inputs[:self.n_seqs]
seqs = self_inputs[:self.n_seqs] offset = self.n_seqs
n_ins_mit_mot = numpy.sum([0] + [len(self.tap_array[x]) for x
offset = self.n_seqs in xrange(self.n_mit_mot)])
n_ins_mit_mot = numpy.sum([0] + [ len(self.tap_array[x]) for x outs_mit_mot = self_inputs[offset:offset + n_ins_mit_mot]
in xrange(self.n_mit_mot) ])
outs_mit_mot = self_inputs[offset:offset+n_ins_mit_mot]
offset += n_ins_mit_mot offset += n_ins_mit_mot
n_ins_mit_sot = numpy.sum([0] + [ len(self.tap_array[x]) for x n_ins_mit_sot = numpy.sum([0] + [ len(self.tap_array[x]) for x
...@@ -1082,6 +1082,11 @@ class Scan(PureOp): ...@@ -1082,6 +1082,11 @@ class Scan(PureOp):
# 7.3. compute gradients of the inputs given one output # 7.3. compute gradients of the inputs given one output
for dx, out in enumerate(clean_outputs): for dx, out in enumerate(clean_outputs):
inner_g_out = safe_new(out) inner_g_out = safe_new(out)
###
#### I need to clip the gradient HERE !!
if g_outs_no_shared[dx]: if g_outs_no_shared[dx]:
g_out_slices.append(g_outs_no_shared[dx][0]) g_out_slices.append(g_outs_no_shared[dx][0])
else: else:
......
...@@ -4,11 +4,11 @@ This module provides optimizations for scan ...@@ -4,11 +4,11 @@ This module provides optimizations for scan
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu " __authors__ = ("Razvan Pascanu "
"Frederic Bastien " "Frederic Bastien "
"James Bergstra " "James Bergstra "
"Pascal Lamblin " "Pascal Lamblin "
"Arnaud Bergeron ") "Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal" __copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>" __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
...@@ -32,16 +32,20 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer ...@@ -32,16 +32,20 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
# Logging function for sending warning or info # Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_opt') _logger = logging.getLogger('theano.scan_module.scan_opt')
list_opt_slice = [ tensor.opt.local_abs_merge, list_opt_slice = [tensor.opt.local_abs_merge,
tensor.opt.local_mul_switch_sink, tensor.opt.local_mul_switch_sink,
tensor.opt.local_upcast_elemwise_constant_inputs, tensor.opt.local_upcast_elemwise_constant_inputs,
tensor.opt.local_remove_switch_const_cond, tensor.opt.local_remove_switch_const_cond,
tensor.opt.constant_folding ] tensor.opt.constant_folding]
def warning(*msg): def warning(*msg):
_logger.warning('WARNING theano.scan: '+' '.join(msg)) _logger.warning('WARNING theano.scan: ' + ' '.join(msg))
def info(*msg): def info(*msg):
_logger.info('INFO theano.scan: '+' '.join(msg)) _logger.info('INFO theano.scan: ' + ' '.join(msg))
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def remove_constants_and_unused_inputs_scan(node): def remove_constants_and_unused_inputs_scan(node):
...@@ -58,9 +62,9 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -58,9 +62,9 @@ def remove_constants_and_unused_inputs_scan(node):
return False return False
op = node.op op = node.op
# We only need to take care of sequences and other arguments # We only need to take care of sequences and other arguments
st = op.n_seqs st = op.n_seqs
st += int(numpy.sum([len(x) for x in st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ])) op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
st += op.n_sit_sot st += op.n_sit_sot
st += op.n_shared_outs st += op.n_shared_outs
op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs) op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)
...@@ -70,17 +74,17 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -70,17 +74,17 @@ def remove_constants_and_unused_inputs_scan(node):
out_stuff_inner = op_ins[op.n_seqs:st] out_stuff_inner = op_ins[op.n_seqs:st]
non_seqs = op_ins[st:] non_seqs = op_ins[st:]
st = ( op.n_seqs + st = (op.n_seqs +
op.n_mit_mot + op.n_mit_mot +
op.n_mit_sot + op.n_mit_sot +
op.n_sit_sot + op.n_sit_sot +
op.n_nit_sot + op.n_nit_sot +
op.n_shared_outs +1 ) op.n_shared_outs + 1)
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
out_stuff_outer = node.inputs[1+op.n_seqs:st] out_stuff_outer = node.inputs[1 + op.n_seqs:st]
# To replace constants in the outer graph by clones in the inner graph # To replace constants in the outer graph by clones in the inner graph
givens = {} givens = {}
# All the inputs of the inner graph of the new scan # All the inputs of the inner graph of the new scan
nw_inner = [] nw_inner = []
# Same for the outer graph, initialized w/ number of steps # Same for the outer graph, initialized w/ number of steps
...@@ -88,18 +92,18 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -88,18 +92,18 @@ def remove_constants_and_unused_inputs_scan(node):
all_ins = gof.graph.inputs(op_outs) all_ins = gof.graph.inputs(op_outs)
for idx in xrange(op.n_seqs): for idx in xrange(op.n_seqs):
if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and if (isinstance(node.inputs[idx + 1], tensor.TensorConstant) and
node.inputs[idx+1].tag.unique_value is not None): node.inputs[idx + 1].tag.unique_value is not None):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
# equal # equal
val = tensor.get_constant_value(node.inputs[idx+1]) val = tensor.get_constant_value(node.inputs[idx + 1])
givens[op_ins[idx]] = node.inputs[idx+1].clone()[0] givens[op_ins[idx]] = node.inputs[idx + 1].clone()[0]
except TypeError: except TypeError:
pass pass
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
nw_inner += [op_ins[idx]] nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx+1]] nw_outer += [node.inputs[idx + 1]]
nw_n_seqs = len(nw_inner) nw_n_seqs = len(nw_inner)
# Add outputs stuff # Add outputs stuff
...@@ -114,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -114,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_outer += [nw_out] nw_outer += [nw_out]
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace = givens) op_outs = scan_utils.clone(op_outs, replace=givens)
nw_info = op.info.copy() nw_info = op.info.copy()
nw_info['n_seqs'] = nw_n_seqs nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK # DEBUG CHECK
...@@ -128,11 +132,12 @@ scan_seqopt = theano.gof.SequenceDB() ...@@ -128,11 +132,12 @@ scan_seqopt = theano.gof.SequenceDB()
optdb.register('scan_seqopt', scan_seqopt, 1.9, 'fast_run', 'scan') optdb.register('scan_seqopt', scan_seqopt, 1.9, 'fast_run', 'scan')
scan_seqopt.register('scanOp_remove_constants_and_unused_inputs', scan_seqopt.register('scanOp_remove_constants_and_unused_inputs',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
ignore_newtrees = True), ignore_newtrees=True),
5, 5,
'fast_run', 'fast_run',
'scan') 'scan')
# This is a global opt for historical reason # This is a global opt for historical reason
# It should be possible to change it to a local opt. # It should be possible to change it to a local opt.
class PushOutNonSeqScan(gof.Optimizer): class PushOutNonSeqScan(gof.Optimizer):
...@@ -140,10 +145,9 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -140,10 +145,9 @@ class PushOutNonSeqScan(gof.Optimizer):
def __init__(self): def __init__(self):
gof.Optimizer.__init__(self) gof.Optimizer.__init__(self)
def add_requirements(self,env): def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate()) env.extend(gof.toolbox.ReplaceValidate())
def apply(self, env): def apply(self, env):
nodelist = [x for x in env.toposort() if isinstance(x.op, nodelist = [x for x in env.toposort() if isinstance(x.op,
scan_op.Scan)] scan_op.Scan)]
...@@ -152,34 +156,31 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -152,34 +156,31 @@ class PushOutNonSeqScan(gof.Optimizer):
def process_node(self, env, node): def process_node(self, env, node):
# this flag tells if there was any change during the last iterations # this flag tells if there was any change during the last iterations
changed = True
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_env = gof.Env(clean_inputs, clean_outputs) local_env = gof.Env(clean_inputs, clean_outputs)
max_iterations = 2*len(local_env.toposort()) + 3 max_iterations = 2 * len(local_env.toposort()) + 3
counts = 0 counts = 0
to_remove = [] to_remove = []
to_replace = [] to_replace = []
replace_with_in = [] replace_with_in = []
replace_with_out = [] replace_with_out = []
op = node.op op = node.op
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
st = op.n_seqs st = op.n_seqs
st += int(numpy.sum([len(x) for x in st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ])) op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
st += op.n_sit_sot st += op.n_sit_sot
st += op.n_shared_outs st += op.n_shared_outs
non_seqs = clean_inputs[st:] non_seqs = clean_inputs[st:]
st = ( op.n_seqs + st = (op.n_seqs +
op.n_mit_mot + op.n_mit_mot +
op.n_mit_sot + op.n_mit_sot +
op.n_sit_sot + op.n_sit_sot +
op.n_nit_sot + op.n_nit_sot +
op.n_shared_outs +1 ) op.n_shared_outs + 1)
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
assert len(non_seqs) == len(outer_non_seqs) assert len(non_seqs) == len(outer_non_seqs)
while changed and counts < max_iterations: while changed and counts < max_iterations:
...@@ -187,15 +188,15 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -187,15 +188,15 @@ class PushOutNonSeqScan(gof.Optimizer):
changed = False changed = False
for nd in local_env.toposort(): for nd in local_env.toposort():
if ( numpy.all([ (x in non_seqs) or if (numpy.all([(x in non_seqs) or
(x.owner in to_remove) or (x.owner in to_remove) or
isinstance(x, tensor.Constant) isinstance(x, tensor.Constant)
for x in nd.inputs]) and for x in nd.inputs]) and
# we can do this because the assumption is that a # we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the # viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle .. # function and not somewhere in the middle ..
not isinstance(nd.op,theano.compile.ViewOp) and not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op,theano.compile.DeepCopyOp) and not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node # and we didn't already looked at this node
not nd in to_remove not nd in to_remove
): ):
...@@ -206,49 +207,50 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -206,49 +207,50 @@ class PushOutNonSeqScan(gof.Optimizer):
outside_ins = [] outside_ins = []
for x in nd.inputs: for x in nd.inputs:
if x in non_seqs: if x in non_seqs:
outside_ins +=[ outer_non_seqs[non_seqs.index(x)]] outside_ins += [outer_non_seqs[non_seqs.index(x)]]
elif x in to_replace: elif x in to_replace:
outside_ins +=[replace_with_out[to_replace.index(x)]] outside_ins += [
replace_with_out[to_replace.index(x)]]
elif isinstance(x, theano.Constant): elif isinstance(x, theano.Constant):
outside_ins +=[x.clone()] outside_ins += [x.clone()]
else: else:
raise Exception( raise Exception(
('Error in the `scan_pushout_non_seq_operations`' ('Error in the `scan_pushout_non_seq_'
'. The optimization tries to move some ' 'operations`. The optimization tries '
'computation fron scan which is not allowed ' 'to move some computation fron scan '
'to move. Report this on theano-users list'),x ) 'which is not allowed to move. Report '
'this on theano-users list'), x)
nw_outer_node = nd.op.make_node(*outside_ins) nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx,y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y,'_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y] to_replace += [y]
replace_with_in += [y_place_holder] replace_with_in += [y_place_holder]
assert type(y) == type(nw_outer_node.outputs[idx]) assert type(y) == type(nw_outer_node.outputs[idx])
replace_with_out += [nw_outer_node.outputs[idx]] replace_with_out += [nw_outer_node.outputs[idx]]
changed = True changed = True
if counts >= max_iterations: if counts >= max_iterations:
raise Exception( ('Error in the `scan_pushout_non_seq_operations`.' raise Exception('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number ' ' The optimization exhausted the maximal number '
'of iterations allowed!')) 'of iterations allowed!')
# We need to check all candidate replacements and choose those that # We need to check all candidate replacements and choose those that
# make sense for us # make sense for us
# Step 1. which elements of `to_replace` are used by remaining # Step 1. which elements of `to_replace` are used by remaining
# components of the inner function # components of the inner function
clean_to_replace = [] clean_to_replace = []
clean_replace_with_in = [] clean_replace_with_in = []
clean_replace_with_out = [] clean_replace_with_out = []
existent_nodes = [ nd for nd in local_env.toposort() existent_nodes = [nd for nd in local_env.toposort()
if nd not in to_remove] if nd not in to_remove]
to_keep = [] to_keep = []
for nd in existent_nodes: for nd in existent_nodes:
to_keep += nd.inputs to_keep += nd.inputs
for idx,out in enumerate(to_replace): for idx, out in enumerate(to_replace):
if out in to_keep and out.owner not in existent_nodes: if out in to_keep and out.owner not in existent_nodes:
clean_to_replace += [out] clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]] clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]] clean_replace_with_out += [replace_with_out[idx]]
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
...@@ -256,7 +258,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -256,7 +258,7 @@ class PushOutNonSeqScan(gof.Optimizer):
givens = {} givens = {}
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip( clean_to_replace, for to_repl, repl_in, repl_out in zip(clean_to_replace,
clean_replace_with_in, clean_replace_with_in,
clean_replace_with_out): clean_replace_with_out):
if isinstance(repl_out, theano.Constant): if isinstance(repl_out, theano.Constant):
...@@ -274,7 +276,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -274,7 +276,7 @@ class PushOutNonSeqScan(gof.Optimizer):
nwScan = scan_op.Scan(op_ins, op_outs, op.info) nwScan = scan_op.Scan(op_ins, op_outs, op.info)
nw_node = nwScan.make_node(* (node.inputs + nw_outer)) nw_node = nwScan.make_node(* (node.inputs + nw_outer))
env.replace_all_validate(zip(node.outputs, nw_node.outputs), env.replace_all_validate(zip(node.outputs, nw_node.outputs),
reason = 'scan_push_computation_out') reason='scan_push_computation_out')
return True return True
elif to_keep == []: elif to_keep == []:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
...@@ -290,7 +292,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -290,7 +292,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# We need to add one extra dimension to the outputs # We need to add one extra dimension to the outputs
env.replace_all_validate(replace_with.items(), env.replace_all_validate(replace_with.items(),
reason = 'scan_push_computation_out') reason='scan_push_computation_out')
else: else:
return False return False
...@@ -306,17 +308,17 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops', ...@@ -306,17 +308,17 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops',
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def scan_make_inplace(node): def scan_make_inplace(node):
op = node.op op = node.op
if ( isinstance(op, scan_op.Scan) and if (isinstance(op, scan_op.Scan) and
(not op.info['inplace']) and (not op.info['inplace']) and
(not op.info['gpu'])): (not op.info['gpu'])):
info = op.info.copy() info = op.info.copy()
info['inplace'] = True info['inplace'] = True
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1+op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node) ls = op.outer_mitmot(node)
ls += op.outer_mitsot(node) ls += op.outer_mitsot(node)
ls += op.outer_sitsot(node) ls += op.outer_sitsot(node)
ls_end = op.outer_shared(node) ls_end = op.outer_shared(node)
ls_end += op.outer_nitsot(node) ls_end += op.outer_nitsot(node)
ls_end += op.outer_non_seqs(node) ls_end += op.outer_non_seqs(node)
n_outs = len(ls) n_outs = len(ls)
...@@ -325,19 +327,18 @@ def scan_make_inplace(node): ...@@ -325,19 +327,18 @@ def scan_make_inplace(node):
ls[idx] = deep_copy_op(ls[idx]) ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end inputs = ls_begin + ls + ls_end
new_op = scan_op.Scan( op.inputs new_op = scan_op.Scan(op.inputs,
, op.outputs op.outputs,
, info) info)
return new_op.make_node(*inputs).outputs return new_op.make_node(*inputs).outputs
return False return False
optdb.register( 'scanOp_make_inplace' optdb.register('scanOp_make_inplace',
, opt.in2out(scan_make_inplace,ignore_newtrees=True) opt.in2out(scan_make_inplace, ignore_newtrees=True),
, 75 75,
, 'fast_run' 'fast_run',
, 'inplace' 'inplace',
, 'scan') 'scan')
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
...@@ -345,24 +346,25 @@ class ScanSaveMem(gof.Optimizer): ...@@ -345,24 +346,25 @@ class ScanSaveMem(gof.Optimizer):
def __init__(self): def __init__(self):
gof.Optimizer.__init__(self) gof.Optimizer.__init__(self)
def add_requirements(self,env): def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate()) env.extend(gof.toolbox.ReplaceValidate())
def process_node(self, env, node): def process_node(self, env, node):
# helpful functions # helpful functions
def select_min(x,y): def select_min(x, y):
if x is None: if x is None:
return y return y
if y is None: if y is None:
return x return x
return tensor.minimum(x,y) return tensor.minimum(x, y)
def select_max(x,y):
def select_max(x, y):
if x is None: if x is None:
return y return y
if y is None: if y is None:
return x return x
return tensor.maximum(x,y) return tensor.maximum(x, y)
def sanitize(x): def sanitize(x):
if x is None: if x is None:
...@@ -383,9 +385,9 @@ class ScanSaveMem(gof.Optimizer): ...@@ -383,9 +385,9 @@ class ScanSaveMem(gof.Optimizer):
op = node.op op = node.op
c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot
init_l = [ 0 for x in xrange(op.n_mit_mot)] init_l = [0 for x in xrange(op.n_mit_mot)]
init_l += [ abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:] ] init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
init_l += [ 0 for x in xrange(op.n_nit_sot)] init_l += [0 for x in xrange(op.n_nit_sot)]
# 2. Check the clients of each output and see for how many steps # 2. Check the clients of each output and see for how many steps
# does scan need to run # does scan need to run
...@@ -408,13 +410,13 @@ class ScanSaveMem(gof.Optimizer): ...@@ -408,13 +410,13 @@ class ScanSaveMem(gof.Optimizer):
# change the number of steps in that case. To do this we set # change the number of steps in that case. To do this we set
# global_nsteps to None which is seen as a flag that nothing needs # global_nsteps to None which is seen as a flag that nothing needs
# to be done # to be done
if len(node.outputs) <= c_outs : if len(node.outputs) <= c_outs:
global_nsteps = {'real' :-1, 'sym': []} global_nsteps = {'real': -1, 'sym': []}
else: else:
global_nsteps = None global_nsteps = None
# Keeps track of the original slices that each client represent # Keeps track of the original slices that each client represent
slices = [ None for o in node.outputs] slices = [None for o in node.outputs]
# A list for each output indicating how many intermediate values # A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate # should be stored. If negative it means none of the intermediate
...@@ -425,31 +427,31 @@ class ScanSaveMem(gof.Optimizer): ...@@ -425,31 +427,31 @@ class ScanSaveMem(gof.Optimizer):
# Note that for mit_mot outputs and shared outputs we can not change # Note that for mit_mot outputs and shared outputs we can not change
# the number of intermediate steps stored without affecting the # the number of intermediate steps stored without affecting the
# result of the op # result of the op
store_steps = [ 0 for o in xrange(op.n_mit_mot)] store_steps = [0 for o in xrange(op.n_mit_mot)]
store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]] store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
# Flag that says if an input has changed and we need to do something # Flag that says if an input has changed and we need to do something
# or not # or not
flag_store = False flag_store = False
# 2.2 Loop over the clients # 2.2 Loop over the clients
for i,out in enumerate(node.outputs[:c_outs]): for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients # look at all its clients
slices[i] = [] slices[i] = []
for cl,_ in out.clients: for cl, _ in out.clients:
# 2.1 outputs of the function # 2.1 outputs of the function
#=> output needs all its intermediate values #=> output needs all its intermediate values
if type(cl) == str: if type(cl) == str:
# if the node is actually an output, then # if the node is actually an output, then
# we need to store the entire thing # we need to store the entire thing
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
# 2.2 non-subtensor nodes # 2.2 non-subtensor nodes
#=> output needs all its intermediate values #=> output needs all its intermediate values
elif not isinstance(cl.op, tensor.basic.Subtensor): elif not isinstance(cl.op, tensor.basic.Subtensor):
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
# 2.3 subtensor nodes # 2.3 subtensor nodes
#=> output might need to store just a subset of its values #=> output might need to store just a subset of its values
...@@ -460,13 +462,11 @@ class ScanSaveMem(gof.Optimizer): ...@@ -460,13 +462,11 @@ class ScanSaveMem(gof.Optimizer):
if this_slice == None: if this_slice == None:
# if unable to extract idx_list # if unable to extract idx_list
#=> outputs needs all its intermediate values #=> outputs needs all its intermediate values
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
# 2.3.2 extract the begin/end of the first dimension # 2.3.2 extract the begin/end of the first dimension
if i > op.n_mit_mot: if i > op.n_mit_mot:
try: try:
length = shape_of[out][0] length = shape_of[out][0]
...@@ -479,26 +479,27 @@ class ScanSaveMem(gof.Optimizer): ...@@ -479,26 +479,27 @@ class ScanSaveMem(gof.Optimizer):
length = out.shape[0] length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice( cf_slice = tensor.basic.get_canonical_form_slice(
this_slice[0], length) this_slice[0], length)
slices[i] += [(cf_slice,this_slice)] slices[i] += [(cf_slice, this_slice)]
if ( isinstance(this_slice[0],slice) and if (isinstance(this_slice[0], slice) and
this_slice[0].stop is None ): this_slice[0].stop is None):
global_nsteps = None global_nsteps = None
break break
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
stop = tensor.basic.extract_constant(cf_slice[0].stop) stop = tensor.basic.extract_constant(cf_slice[0].stop)
else: else:
stop = tensor.basic.extract_constant(cf_slice[0]) + 1 stop = tensor.basic.extract_constant(cf_slice[0]) + 1
if stop == sys.maxint or stop == length: if stop == sys.maxint or stop == length:
stop = None stop = None
else: else:
# there is a **gotcha** here ! Namely, scan returns an # there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output as # array that contains the initial state of the output
# well. Which means that if have a initial state of # as well. Which means that if have a initial state of
# length 3, and you look for 5 steps you get an output y # length 3, and you look for 5 steps you get an output
# of length 8. If you only use y[:5], this does not mean # y of length 8. If you only use y[:5], this does not
# that you only need to loop for 5 steps but actually # mean that you only need to loop for 5 steps but
# only for 2 steps ( the first 3 are the initial state) # actually only for 2 steps ( the first 3 are the
# initial state)
stop = stop - init_l[i] stop = stop - init_l[i]
# 2.3.3 we might get away with less number of steps # 2.3.3 we might get away with less number of steps
...@@ -510,10 +511,11 @@ class ScanSaveMem(gof.Optimizer): ...@@ -510,10 +511,11 @@ class ScanSaveMem(gof.Optimizer):
elif (type(stop) is int and stop == sys.maxint): elif (type(stop) is int and stop == sys.maxint):
global_nsteps = None global_nsteps = None
# yes if it is a int k, 0 < k < maxint # yes if it is a int k, 0 < k < maxint
elif (type(stop) is int and global_nsteps['real'] < stop): elif (type(stop) is int and
global_nsteps['real'] < stop):
global_nsteps['real'] = stop global_nsteps['real'] = stop
# yes if it is a int k, 0 < k < maxint # yes if it is a int k, 0 < k < maxint
elif (type(stop) is int and stop > 0 ): elif (type(stop) is int and stop > 0):
pass pass
# not otherwise # not otherwise
else: else:
...@@ -526,10 +528,10 @@ class ScanSaveMem(gof.Optimizer): ...@@ -526,10 +528,10 @@ class ScanSaveMem(gof.Optimizer):
# there are some symbolic tensors that limit the number of # there are some symbolic tensors that limit the number of
# steps # steps
if len(global_nsteps['sym']) == 0 : if len(global_nsteps['sym']) == 0:
sym_steps = None sym_steps = None
else: else:
sym_steps =global_nsteps['sym'][0] sym_steps = global_nsteps['sym'][0]
for c in global_nsteps['sym'][1:]: for c in global_nsteps['sym'][1:]:
sym_steps = tensor.maximum(sym_steps, c) sym_steps = tensor.maximum(sym_steps, c)
...@@ -543,12 +545,11 @@ class ScanSaveMem(gof.Optimizer): ...@@ -543,12 +545,11 @@ class ScanSaveMem(gof.Optimizer):
nw_steps = node.inputs[0] nw_steps = node.inputs[0]
global_nsteps = None global_nsteps = None
# 2.4 Loop over the clients again now looking just to see how many # 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store # intermediate steps to store
for i,out in enumerate(node.outputs[:c_outs]): for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients # look at all its clients
for cl,_ in out.clients: for cl, _ in out.clients:
if type(cl) == str: if type(cl) == str:
store_steps[i] = 0 store_steps[i] = 0
break break
...@@ -562,7 +563,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -562,7 +563,7 @@ class ScanSaveMem(gof.Optimizer):
store_steps[i] = 0 store_steps[i] = 0
break break
if ( isinstance(this_slice[0],slice) and if (isinstance(this_slice[0], slice) and
this_slice[0].start is None): this_slice[0].start is None):
store_steps[i] = 0 store_steps[i] = 0
break break
...@@ -575,46 +576,48 @@ class ScanSaveMem(gof.Optimizer): ...@@ -575,46 +576,48 @@ class ScanSaveMem(gof.Optimizer):
except Exception: except Exception:
length = out.shape[0] length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice( cf_slice = tensor.basic.get_canonical_form_slice(
this_slice[0],length) this_slice[0], length)
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
start = tensor.basic.extract_constant(cf_slice[0].start) start = tensor.basic.extract_constant(
cf_slice[0].start)
else: else:
start = tensor.basic.extract_constant(cf_slice[0]) start = tensor.basic.extract_constant(cf_slice[0])
if start == 0 or store_steps[i] == 0: if start == 0 or store_steps[i] == 0:
store_steps[i] = 0 store_steps[i] = 0
else: else:
pval = select_max(nw_steps -start + init_l[i], init_l[i]) pval = select_max(nw_steps - start + init_l[i],
init_l[i])
if store_steps[i] != -1: if store_steps[i] != -1:
pval = select_max(pval, store_steps[i]) pval = select_max(pval, store_steps[i])
store_steps[i] = pval store_steps[i] = pval
flag_store = True flag_store = True
orphane_outs = [ i for i,x in enumerate(store_steps) orphane_outs = [i for i, x in enumerate(store_steps)
if (type(x) is int) and (x<0) ] if (type(x) is int) and (x < 0)]
flag_store = flag_store or (len(orphane_outs) > 0 ) flag_store = flag_store or (len(orphane_outs) > 0)
# 3. is there anything to change ? # 3. is there anything to change ?
if (flag_store or global_nsteps is not None): if (flag_store or global_nsteps is not None):
# 3.1 initialize inputs for the new scan # 3.1 initialize inputs for the new scan
old_outputs = [] old_outputs = []
nw_inputs = list(node.inputs) nw_inputs = list(node.inputs)
nw_inputs[0] = nw_steps nw_inputs[0] = nw_steps
# 3.2 check orphane outputs to see if we can eliminate any # 3.2 check orphane outputs to see if we can eliminate any
required,not_required = \ required, not_required = \
scan_utils.scan_can_remove_outs(node.op scan_utils.scan_can_remove_outs(node.op,
, orphane_outs) orphane_outs)
# 3.3. compose replace pairs for those nodes that need not # 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required # to store everything in memory ( or ar orphane and required
# by the inner function .. ) # by the inner function .. )
replaced_outs = [] replaced_outs = []
offset = 1 + op.n_seqs + op.n_mit_mot offset = 1 + op.n_seqs + op.n_mit_mot
for idx,_val in enumerate(store_steps[op.n_mit_mot:]): for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
i = idx + op.n_mit_mot i = idx + op.n_mit_mot
if not( type(_val) is int and _val <=0 and i not in required): if not(type(_val) is int and _val <= 0 and i not in required):
if idx+op.n_mit_mot in required: if idx + op.n_mit_mot in required:
val = 1 val = 1
else: else:
val = _val val = _val
...@@ -626,21 +629,21 @@ class ScanSaveMem(gof.Optimizer): ...@@ -626,21 +629,21 @@ class ScanSaveMem(gof.Optimizer):
# a) the input is a set_subtensor, in that case we # a) the input is a set_subtensor, in that case we
# can replace the initial tensor by a slice, # can replace the initial tensor by a slice,
# b) it is not, and we simply take a slice of it. # b) it is not, and we simply take a slice of it.
if (nw_inputs[offset+idx].owner and if (nw_inputs[offset + idx].owner and
isinstance(nw_inputs[offset+idx].owner.op, isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor)): tensor.IncSubtensor)):
_nw_input = nw_inputs[offset+idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val - init_l[i])) tensor.as_tensor_variable(val - init_l[i]))
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand( _nw_input,tmp ) nw_input = scan_utils.expand(_nw_input, tmp)
else: else:
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val)) tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp] nw_input = nw_inputs[offset + idx][:tmp]
nw_inputs[offset+idx] = nw_input nw_inputs[offset + idx] = nw_input
replaced_outs.append(op.n_mit_mot + idx) replaced_outs.append(op.n_mit_mot + idx)
odx = op.n_mit_mot + idx odx = op.n_mit_mot + idx
old_outputs += [(odx, [x[0].outputs[0] for x in old_outputs += [(odx, [x[0].outputs[0] for x in
...@@ -648,8 +651,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -648,8 +651,8 @@ class ScanSaveMem(gof.Optimizer):
# If there is no memory pre-allocated for this output # If there is no memory pre-allocated for this output
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
pos = ( op.n_mit_mot + idx + op.n_seqs pos = (op.n_mit_mot + idx + op.n_seqs +
+ 1 + op.n_shared_outs ) 1 + op.n_shared_outs)
if nw_inputs[pos] == node.inputs[0]: if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val nw_inputs[pos] = val
odx = op.n_mit_mot + idx odx = op.n_mit_mot + idx
...@@ -662,43 +665,41 @@ class ScanSaveMem(gof.Optimizer): ...@@ -662,43 +665,41 @@ class ScanSaveMem(gof.Optimizer):
for idx, val in enumerate(store_steps[op.n_mit_mot:]): for idx, val in enumerate(store_steps[op.n_mit_mot:]):
if val == 0: if val == 0:
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input = nw_inputs[offset+idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
odx = op.n_mit_mot + idx odx = op.n_mit_mot + idx
nw_input = scan_utils.expand(_nw_input, nw_steps) nw_input = scan_utils.expand(_nw_input, nw_steps)
nw_inputs[offset+idx] = nw_input nw_inputs[offset + idx] = nw_input
elif idx < (op.n_mit_sot + op.n_sit_sot + elif idx < (op.n_mit_sot + op.n_sit_sot +
+ op.n_nit_sot): op.n_nit_sot):
in_idx = offset+idx+op.n_shared_outs in_idx = offset + idx + op.n_shared_outs
if nw_inputs[in_idx] == node.inputs[0]: if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] =nw_steps nw_inputs[in_idx] = nw_steps
odx = op.n_mit_mot + idx odx = op.n_mit_mot + idx
# 3.5 Remove unwanted orphane outputs # 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \ (inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs) scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {} inv_compress_map = {}
for k,v in compress_map.items(): for k, v in compress_map.items():
inv_compress_map[v] = k inv_compress_map[v] = k
node_ins = [ pre_greedy_local_optimizer(list_opt_slice, x) for x in node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in
node_ins] node_ins]
node_ins = pre_constant_merge(node_ins) node_ins = pre_constant_merge(node_ins)
# 3.6 Compose the new scan # 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization # I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that # twice since bad things usually happen if I do that
info['_scan_merge_visited'] = True info['_scan_merge_visited'] = True
new_outs = scan_op.Scan(inps new_outs = scan_op.Scan(inps,
, outs outs,
, info).make_node(*node_ins).outputs info).make_node(*node_ins).outputs
old_new = [] old_new = []
# 3.7 Get replace pairs for those outputs that do not change # 3.7 Get replace pairs for those outputs that do not change
# the number of intermediate steps stored # the number of intermediate steps stored
for idx,sl in enumerate(slices): for idx, sl in enumerate(slices):
if global_nsteps and sl is not None and store_steps[idx] == 0: if global_nsteps and sl is not None and store_steps[idx] == 0:
for hdx,cl in enumerate(node.outputs[idx].clients): for hdx, cl in enumerate(node.outputs[idx].clients):
cnf_slice, old_slices = sl[hdx] cnf_slice, old_slices = sl[hdx]
# Sanitize the nw_slice by converting ints back into # Sanitize the nw_slice by converting ints back into
# constants :) I only need to do this for the first # constants :) I only need to do this for the first
...@@ -713,18 +714,16 @@ class ScanSaveMem(gof.Optimizer): ...@@ -713,18 +714,16 @@ class ScanSaveMem(gof.Optimizer):
else: else:
fslice = sanitize(cnf_slice[0]) fslice = sanitize(cnf_slice[0])
nw_slice = (fslice,) + tuple(old_slices[1:]) nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = inv_compress_map[idx] nw_pos = inv_compress_map[idx]
nw_out = new_outs[nw_pos] nw_out = new_outs[nw_pos]
subtens = tensor.basic.Subtensor(nw_slice) subtens = tensor.basic.Subtensor(nw_slice)
# slice inputs # slice inputs
sl_ins = tensor.basic.Subtensor.collapse( sl_ins = tensor.basic.Subtensor.collapse(
nw_slice nw_slice,
, lambda entry: isinstance(entry lambda entry: isinstance(entry,
, tensor.Variable)) tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos], new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0] *sl_ins).outputs[0]
if new_o.ndim > 0: if new_o.ndim > 0:
...@@ -737,34 +736,35 @@ class ScanSaveMem(gof.Optimizer): ...@@ -737,34 +736,35 @@ class ScanSaveMem(gof.Optimizer):
if len(old_outs) > 0: if len(old_outs) > 0:
nw_pos = compress_map[pos] nw_pos = compress_map[pos]
nw_out = new_outs[nw_pos] nw_out = new_outs[nw_pos]
for k,old in enumerate(old_outs): for k, old in enumerate(old_outs):
# Get the correct slice # Get the correct slice
cnf_slice, old_slices = slices[pos][k] cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice: if type(cnf_slice[0]) is slice:
start = ( cnf_slice[0].start - nw_steps - start = (cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos] ) init_l[pos] + store_steps[pos])
if ( cnf_slice[0].stop is not None and if (cnf_slice[0].stop is not None and
cnf_slice[0].stop != sys.maxint ): cnf_slice[0].stop != sys.maxint):
stop = ( cnf_slice[0].stop - nw_steps - stop = (cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos]) init_l[pos] + store_steps[pos])
else: else:
stop = None stop = None
nw_slice = ( (slice(sanitize(start), nw_slice = ((slice(sanitize(start),
sanitize(stop), sanitize(stop),
sanitize(cnf_slice[0].step)),) + sanitize(cnf_slice[0].step)),)
tuple(old_slices[1:]) ) + tuple(old_slices[1:]))
else: else:
position = (cnf_slice[0] - nw_steps - position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos] ) init_l[pos] + store_steps[pos])
nw_slice = (sanitize(position),) + tuple(old_slices[1:]) nw_slice = (sanitize(position),) + \
tuple(old_slices[1:])
subtens = tensor.basic.Subtensor(nw_slice) subtens = tensor.basic.Subtensor(nw_slice)
sl_ins = tensor.basic.Subtensor.collapse( sl_ins = tensor.basic.Subtensor.collapse(
nw_slice nw_slice,
, lambda entry: isinstance(entry lambda entry: isinstance(entry,
, tensor.Variable)) tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos], new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0] *sl_ins).outputs[0]
if new_o.ndim > 0: if new_o.ndim > 0:
...@@ -773,13 +773,12 @@ class ScanSaveMem(gof.Optimizer): ...@@ -773,13 +773,12 @@ class ScanSaveMem(gof.Optimizer):
# 3.9. Get replace pairs for all other nodes # 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None: if flag_store or global_nsteps is not None:
for idx,o in enumerate(node.outputs): for idx, o in enumerate(node.outputs):
if not (idx in replaced_outs) and not idx in not_required: if not (idx in replaced_outs) and not idx in not_required:
nw_pos = compress_map[idx] nw_pos = compress_map[idx]
old_new += [(o,new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
env.replace_all_validate(old_new, reason = 'scan_save_mem')
env.replace_all_validate(old_new, reason='scan_save_mem')
def apply(self, env): def apply(self, env):
...@@ -792,16 +791,16 @@ class ScanSaveMem(gof.Optimizer): ...@@ -792,16 +791,16 @@ class ScanSaveMem(gof.Optimizer):
# Just before specialize to have the other optimization # Just before specialize to have the other optimization
# like constant folding being applied # like constant folding being applied
# This don't introduce inplace. # This don't introduce inplace.
scan_seqopt.register( 'scanOp_save_mem', scan_seqopt.register('scanOp_save_mem',
ScanSaveMem(), ScanSaveMem(),
4, 4,
'fast_run', 'fast_run',
'scan') 'scan')
class ScanMerge(gof.Optimizer): class ScanMerge(gof.Optimizer):
""" Graph Optimizer that merges different scan ops """ """ Graph Optimizer that merges different scan ops """
def add_requirements(self,env): def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate()) env.extend(gof.toolbox.ReplaceValidate())
def merge(self, nodes): def merge(self, nodes):
...@@ -812,29 +811,26 @@ class ScanMerge(gof.Optimizer): ...@@ -812,29 +811,26 @@ class ScanMerge(gof.Optimizer):
else: else:
as_while = False as_while = False
info = {}
info = {} info['tap_array'] = []
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
info['tap_array'] = [] info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes]) info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
info['mit_mot_out_slices'] = [] info['mit_mot_out_slices'] = []
info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes]) info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes])
info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes]) info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes])
info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes]) info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes])
info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes]) info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes])
info['truncate_gradient'] = nodes[0].op.truncate_gradient info['truncate_gradient'] = nodes[0].op.truncate_gradient
info['name'] = '&'.join([nd.op.name for nd in nodes]) info['name'] = '&'.join([nd.op.name for nd in nodes])
info['mode'] = nodes[0].op.mode info['mode'] = nodes[0].op.mode
info['inplace'] = False info['inplace'] = False
info['gpu'] = False info['gpu'] = False
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = nodes[0].op.profile info['profile'] = nodes[0].op.profile
inner_ins = []
inner_ins = [] outer_ins = []
outer_ins = []
inner_outs = [] inner_outs = []
outer_outs = [] outer_outs = []
...@@ -844,57 +840,56 @@ class ScanMerge(gof.Optimizer): ...@@ -844,57 +840,56 @@ class ScanMerge(gof.Optimizer):
k.name += str(suffix) k.name += str(suffix)
return ls return ls
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Seq # Seq
inner_ins += rename(nd.op.inner_seqs(),idx) inner_ins += rename(nd.op.inner_seqs(), idx)
outer_ins += rename(nd.op.outer_seqs(nd),idx) outer_ins += rename(nd.op.outer_seqs(nd), idx)
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitMot # MitMot
inner_ins += rename(nd.op.inner_mitmot(),idx) inner_ins += rename(nd.op.inner_mitmot(), idx)
inner_outs += nd.op.inner_mitmot_outs() inner_outs += nd.op.inner_mitmot_outs()
info['tap_array'] += nd.op.mitmot_taps() info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps() info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd),idx) outer_ins += rename(nd.op.outer_mitmot(nd), idx)
outer_outs += nd.op.outer_mitmot_outs(nd) outer_outs += nd.op.outer_mitmot_outs(nd)
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitSot # MitSot
inner_ins += rename(nd.op.inner_mitsot(),idx) inner_ins += rename(nd.op.inner_mitsot(), idx)
inner_outs += nd.op.inner_mitsot_outs() inner_outs += nd.op.inner_mitsot_outs()
info['tap_array'] += nd.op.mitsot_taps() info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd),idx) outer_ins += rename(nd.op.outer_mitsot(nd), idx)
outer_outs += nd.op.outer_mitsot_outs(nd) outer_outs += nd.op.outer_mitsot_outs(nd)
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# SitSot # SitSot
inner_ins += rename(nd.op.inner_sitsot(),idx) inner_ins += rename(nd.op.inner_sitsot(), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)] info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs() inner_outs += nd.op.inner_sitsot_outs()
outer_ins += rename(nd.op.outer_sitsot(nd),idx) outer_ins += rename(nd.op.outer_sitsot(nd), idx)
outer_outs += nd.op.outer_sitsot_outs(nd) outer_outs += nd.op.outer_sitsot_outs(nd)
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
inner_ins += rename(nd.op.inner_shared(),idx) inner_ins += rename(nd.op.inner_shared(), idx)
outer_ins += rename(nd.op.outer_shared(nd),idx) outer_ins += rename(nd.op.outer_shared(nd), idx)
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# NitSot # NitSot
inner_outs += nd.op.inner_nitsot_outs() inner_outs += nd.op.inner_nitsot_outs()
outer_ins += rename(nd.op.outer_nitsot(nd),idx) outer_ins += rename(nd.op.outer_nitsot(nd), idx)
outer_outs += nd.op.outer_nitsot_outs(nd) outer_outs += nd.op.outer_nitsot_outs(nd)
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
outer_outs += nd.op.outer_shared_outs(nd) outer_outs += nd.op.outer_shared_outs(nd)
inner_outs += nd.op.inner_shared_outs() inner_outs += nd.op.inner_shared_outs()
for idx,nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Non Seqs # Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(),idx) inner_ins += rename(nd.op.inner_non_seqs(), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd),idx) outer_ins += rename(nd.op.outer_non_seqs(nd), idx)
# Add back the number of steps # Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins outer_ins = [nodes[0].inputs[0]] + outer_ins
...@@ -913,8 +908,6 @@ class ScanMerge(gof.Optimizer): ...@@ -913,8 +908,6 @@ class ScanMerge(gof.Optimizer):
return zip(outer_outs, new_outs) return zip(outer_outs, new_outs)
def belongs_to_set(self, node, set_nodes): def belongs_to_set(self, node, set_nodes):
""" """
This function checks if node `node` belongs to `set_nodes`, in the This function checks if node `node` belongs to `set_nodes`, in the
...@@ -934,7 +927,6 @@ class ScanMerge(gof.Optimizer): ...@@ -934,7 +927,6 @@ class ScanMerge(gof.Optimizer):
except TypeError: except TypeError:
pass pass
rep_nsteps = rep.inputs[0] rep_nsteps = rep.inputs[0]
try: try:
rep_nsteps = int(get_constant_value(rep_nsteps)) rep_nsteps = int(get_constant_value(rep_nsteps))
...@@ -959,11 +951,9 @@ class ScanMerge(gof.Optimizer): ...@@ -959,11 +951,9 @@ class ScanMerge(gof.Optimizer):
rep.op.inputs) rep.op.inputs)
return same_cond and (nsteps == rep_nsteps) and can_add return same_cond and (nsteps == rep_nsteps) and can_add
def apply(self, env): def apply(self, env):
# Collect all scan nodes ordered according to toposort # Collect all scan nodes ordered according to toposort
scan_nodes = [ nd for nd in env.toposort() scan_nodes = [nd for nd in env.toposort()
if isinstance(nd.op, scan_op.Scan)] if isinstance(nd.op, scan_op.Scan)]
# All sets of possibly mergeable nodes # All sets of possibly mergeable nodes
...@@ -971,7 +961,7 @@ class ScanMerge(gof.Optimizer): ...@@ -971,7 +961,7 @@ class ScanMerge(gof.Optimizer):
for nd in scan_nodes: for nd in scan_nodes:
belongs_to_set_idx = -1 belongs_to_set_idx = -1
for pos,subset in enumerate(all_sets): for pos, subset in enumerate(all_sets):
if self.belongs_to_set(nd, subset): if self.belongs_to_set(nd, subset):
assert belongs_to_set_idx == -1 assert belongs_to_set_idx == -1
belongs_to_set_idx = pos belongs_to_set_idx = pos
...@@ -984,7 +974,7 @@ class ScanMerge(gof.Optimizer): ...@@ -984,7 +974,7 @@ class ScanMerge(gof.Optimizer):
for subset in all_sets: for subset in all_sets:
if len(subset) > 1: if len(subset) > 1:
proposal = self.merge(subset) proposal = self.merge(subset)
env.replace_all_validate(proposal, reason = 'scan_merge') env.replace_all_validate(proposal, reason='scan_merge')
# after const merge but before stabilize so that we can have identity # after const merge but before stabilize so that we can have identity
...@@ -996,23 +986,27 @@ scan_seqopt.register('scanOp_merge', ...@@ -996,23 +986,27 @@ scan_seqopt.register('scanOp_merge',
'fast_run', 'fast_run',
'scan') 'scan')
def has_duplicates(l): def has_duplicates(l):
"""returns true if l has any duplicates (according to __eq__).""" """returns true if l has any duplicates (according to __eq__)."""
return len(set(l)) < len(l) return len(set(l)) < len(l)
def make_equiv(lo, li):
    """Build an equivalence between inner inputs implied by equal outer inputs.

    Walks the (outer, inner) input pairs in lockstep.  The first time an
    outer input is encountered its inner input is remembered; every later
    repetition of that same outer input contributes its inner input to
    ``left`` and the repeated outer input itself to ``right``.

    :param lo: list of outer inputs.
    :param li: list of the corresponding inner inputs (same length).
    :return: the pair of lists ``(left, right)`` described above.

    NOTE(review): the original docstring speaks of equivalences *between
    inner inputs*, yet ``right`` collects outer inputs rather than the
    first-seen inner input -- confirm against the callers that this
    pairing is intended.
    """
    first_seen = {}
    left, right = [], []
    for outer, inner in zip(lo, li):
        if outer in first_seen:
            # Repeated outer input: record the equivalence.
            left.append(inner)
            right.append(outer)
        else:
            first_seen[outer] = inner
    return left, right
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def scan_merge_inouts(node): def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan): if not isinstance(node.op, scan_op.Scan):
...@@ -1072,58 +1066,68 @@ def scan_merge_inouts(node): ...@@ -1072,58 +1066,68 @@ def scan_merge_inouts(node):
na = a na = a
# start again # start again
left = [] left = []
right = [] right = []
if has_duplicates(na.outer_in_shared): if has_duplicates(na.outer_in_shared):
_left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared) _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared)
left += _left left += _left
right += _right right += _right
if has_duplicates(na.outer_in_sit_sot): if has_duplicates(na.outer_in_sit_sot):
_left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot) _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot)
left += _left left += _left
right += _right right += _right
if has_duplicates(na.outer_in_mit_mot): if has_duplicates(na.outer_in_mit_mot):
seen = {} seen = {}
for omm, imm, _sl in zip(na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices): for omm, imm, _sl in zip(na.outer_in_mit_mot,
na.inner_in_mit_mot, na.mit_mot_in_slices):
sl = tuple(_sl) sl = tuple(_sl)
if (omm, sl) in seen: if (omm, sl) in seen:
simm = seen[(omm, sl)] simm = seen[(omm, sl)]
left += imm left += imm
right += simm right += simm
else: else:
seen[(omm, sl)] = imm seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot): if has_duplicates(na.outer_in_mit_sot):
seen = {} seen = {}
for oms, ims, _sl in zip(na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices): for oms, ims, _sl in zip(na.outer_in_mit_sot,
na.inner_in_mit_sot,
na.mit_sot_in_slices):
sl = tuple(_sl) sl = tuple(_sl)
if (oms, sl) in seen: if (oms, sl) in seen:
sims = seen[(oms, sl)] sims = seen[(oms, sl)]
left += ims left += ims
right += sims right += sims
else: else:
seen[(oms, sl)] = ims seen[(oms, sl)] = ims
def map_out(i, o, seen): def map_out(i, o, seen):
for si, so in seen: for si, so in seen:
if equal_computations([i], [si],left, right): if equal_computations([i], [si], left, right):
return so return so
seen.append((i, o)) seen.append((i, o))
return o return o
seen = [] seen = []
na.outer_out_nit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_nit_sot, na.outer_out_nit_sot)] na.outer_out_nit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_nit_sot,
na.outer_out_nit_sot)]
seen = [] seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_sit_sot, na.outer_out_sit_sot)] na.outer_out_sit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_sit_sot,
na.outer_out_sit_sot)]
seen = [] seen = []
na.outer_out_mit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_mit_sot, na.outer_out_mit_sot)] na.outer_out_mit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_mit_sot,
na.outer_out_mit_sot)]
seen = [] seen = []
new_outer_out_mit_mot = [] new_outer_out_mit_mot = []
for imm, omm, osl in zip(na.inner_out_mit_mot, na.outer_out_mit_mot, na.mit_mot_out_slices): for imm, omm, osl in zip(na.inner_out_mit_mot,
na.outer_out_mit_mot, na.mit_mot_out_slices):
for simm, somm, sosl in seen: for simm, somm, sosl in seen:
if osl == sosl and equal_computations(imm, simm, left, right): if osl == sosl and equal_computations(imm, simm, left, right):
new_outer_out_mit_mot.append(somm) new_outer_out_mit_mot.append(somm)
...@@ -1136,7 +1140,7 @@ def scan_merge_inouts(node): ...@@ -1136,7 +1140,7 @@ def scan_merge_inouts(node):
return na.outer_outputs return na.outer_outputs
# Register scan_merge_inouts (wrapped as a global optimizer via in2out) in
# the scan optimization sequence, at position 3, under the usual tags.
scan_seqopt.register('scanOp_merge_inouts',
                     opt.in2out(scan_merge_inouts, ignore_newtrees=True),
                     3,
                     'fast_run',
                     'scan')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论