提交 fc22aaf9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

making scan.py PEP8 compatible

上级 80f150b2
...@@ -34,10 +34,10 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``, ...@@ -34,10 +34,10 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
``foldr()``. ``foldr()``.
""" """
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu " __authors__ = ("Razvan Pascanu "
"Frederic Bastien " "Frederic Bastien "
"James Bergstra " "James Bergstra "
"Pascal Lamblin " ) "Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal" __copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>" __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
...@@ -63,16 +63,16 @@ from scan_utils import safe_new, traverse ...@@ -63,16 +63,16 @@ from scan_utils import safe_new, traverse
_logger = logging.getLogger('theano.scan_module.scan') _logger = logging.getLogger('theano.scan_module.scan')
def scan( fn def scan(fn,
, sequences = None sequences=None,
, outputs_info = None outputs_info=None,
, non_sequences = None non_sequences=None,
, n_steps = None n_steps=None,
, truncate_gradient = -1 truncate_gradient=-1,
, go_backwards = False go_backwards=False,
, mode = None mode=None,
, name = None name=None,
, profile = False): profile=False):
""" """
This function constructs and applies a Scan op to the provided This function constructs and applies a Scan op to the provided
arguments. arguments.
...@@ -333,7 +333,7 @@ def scan( fn ...@@ -333,7 +333,7 @@ def scan( fn
''' '''
if x is None: if x is None:
return [] return []
elif not isinstance(x, (list,tuple)): elif not isinstance(x, (list, tuple)):
return [x] return [x]
else: else:
return list(x) return list(x)
...@@ -356,19 +356,19 @@ def scan( fn ...@@ -356,19 +356,19 @@ def scan( fn
# To do that we check here to see the nature of n_steps # To do that we check here to see the nature of n_steps
n_fixed_steps = None n_fixed_steps = None
if isinstance( n_steps, (float,int)): if isinstance(n_steps, (float, int)):
n_fixed_steps = int(n_steps) n_fixed_steps = int(n_steps)
else: else:
try : try:
n_fixed_steps = opt.get_constant_value(n_steps) n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError): except (TypeError, AttributeError):
n_fixed_steps = None n_fixed_steps = None
# Check n_steps is an int # Check n_steps is an int
if ( hasattr(n_steps,'dtype') and if (hasattr(n_steps, 'dtype') and
str(n_steps.dtype)[:3] not in ('uin','int') ): str(n_steps.dtype)[:3] not in ('uin', 'int')):
raise ValueError(' n_steps must be an int. dtype provided ' raise ValueError(' n_steps must be an int. dtype provided '
'is %s'%n_steps.dtype) 'is %s' % n_steps.dtype)
# compute number of sequences and number of outputs # compute number of sequences and number of outputs
n_seqs = len(seqs) n_seqs = len(seqs)
...@@ -377,11 +377,11 @@ def scan( fn ...@@ -377,11 +377,11 @@ def scan( fn
return_steps = {} return_steps = {}
# wrap sequences in a dictionary if they are not already dictionaries # wrap sequences in a dictionary if they are not already dictionaries
for i in xrange(n_seqs): for i in xrange(n_seqs):
if not isinstance(seqs[i], dict) : if not isinstance(seqs[i], dict):
seqs[i] = dict(input=seqs[i], taps=[0]) seqs[i] = dict(input=seqs[i], taps=[0])
elif seqs[i].get('taps',None): elif seqs[i].get('taps', None):
seqs[i]['taps'] = wrap_into_list(seqs[i]['taps']) seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
elif seqs[i].get('taps',True) is None: elif seqs[i].get('taps', True) is None:
# seqs dictionary does not have the ``taps`` key # seqs dictionary does not have the ``taps`` key
seqs[i]['taps'] = [0] seqs[i]['taps'] = [0]
...@@ -391,30 +391,31 @@ def scan( fn ...@@ -391,30 +391,31 @@ def scan( fn
if isinstance(outs_info[i], dict): if isinstance(outs_info[i], dict):
# DEPRICATED : # DEPRICATED :
if outs_info[i].get('return_steps', None): if outs_info[i].get('return_steps', None):
_logger.warning( ("Using `return_steps` has been depricated." _logger.warning(("Using `return_steps` has been "
" Simply select the entries you need using " "depricated. Simply select the entries you "
" a subtensor. Scan will optimize memory " "need using a subtensor. Scan will optimize "
" consumption, so do not worry about that.")) "memory consumption, so do not worry about "
"that."))
return_steps[i] = outs_info[i]['return_steps'] return_steps[i] = outs_info[i]['return_steps']
# END # END
if not isinstance(outs_info[i], dict): if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1 # by default any output has a tap value of -1
outs_info[i] = dict(initial=outs_info[i], taps = [-1]) outs_info[i] = dict(initial=outs_info[i], taps=[-1])
elif (not outs_info[i].get('initial',None) and elif (not outs_info[i].get('initial', None) and
outs_info[i].get('taps',None)): outs_info[i].get('taps', None)):
# ^ no initial state but taps provided # ^ no initial state but taps provided
raise ValueError( ( 'If you are using slices of an output ' raise ValueError(('If you are using slices of an output '
'you need to provide a initial state ' 'you need to provide a initial state '
'for it'), outs_info[i] ) 'for it'), outs_info[i])
elif (outs_info[i].get('initial',None) and elif (outs_info[i].get('initial', None) and
not outs_info[i].get('taps',None)): not outs_info[i].get('taps', None)):
# ^ initial state but taps not provided # ^ initial state but taps not provided
if outs_info[i].has_key('taps'): if 'taps' in outs_info[i]:
# ^ explicitly provided a None for taps # ^ explicitly provided a None for taps
_logger.warning('Output %s ( index %d) has a initial state ' _logger.warning('Output %s ( index %d) has a initial '
'but taps is explicitly set to None ', 'state but taps is explicitly set to None ',
getattr(outs_info[i]['initial'],'name','None'), getattr(outs_info[i]['initial'], 'name', 'None'),
i) i)
outs_info[i]['taps'] = [-1] outs_info[i]['taps'] = [-1]
else: else:
...@@ -439,7 +440,7 @@ def scan( fn ...@@ -439,7 +440,7 @@ def scan( fn
inner_seqs = [] # Variables passed as inputs to the inner function inner_seqs = [] # Variables passed as inputs to the inner function
inner_slices = [] # Actual slices if scan is removed from the picture inner_slices = [] # Actual slices if scan is removed from the picture
# go through sequences picking up time slices as needed # go through sequences picking up time slices as needed
for i,seq in enumerate(seqs): for i, seq in enumerate(seqs):
# Note that you can have something like no taps for # Note that you can have something like no taps for
# a sequence, though is highly unlikely in practice # a sequence, though is highly unlikely in practice
if 'taps' in seq: if 'taps' in seq:
...@@ -456,31 +457,33 @@ def scan( fn ...@@ -456,31 +457,33 @@ def scan( fn
# If not we need to use copies, that will be replaced at # If not we need to use copies, that will be replaced at
# each frame by the corresponding slice # each frame by the corresponding slice
actual_slice = seq['input'][k-mintap] actual_slice = seq['input'][k - mintap]
_seq_val = tensor.as_tensor_variable(seq['input']) _seq_val = tensor.as_tensor_variable(seq['input'])
_seq_val_slice = _seq_val[k-mintap] _seq_val_slice = _seq_val[k - mintap]
nw_slice = _seq_val_slice.type() nw_slice = _seq_val_slice.type()
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != 'off': if config.compute_test_value != 'off':
try: try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice) nw_slice.tag.test_value = gof.Op._get_test_value(
_seq_val_slice)
except AttributeError, e: except AttributeError, e:
if config.compute_test_value != 'ignore': if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
# it will be done when fn will be called. # it will be done when fn will be called.
_logger.info(('Cannot compute test value for the inner ' _logger.info(('Cannot compute test value for '
'function of scan, input value missing %s'), e) 'the inner function of scan, input value '
'missing %s'), e)
# Add names to slices for debugging and pretty printing .. # Add names to slices for debugging and pretty printing ..
# that is if the input already has a name # that is if the input already has a name
if getattr(seq['input'],'name', None) is not None: if getattr(seq['input'], 'name', None) is not None:
if k > 0: if k > 0:
nw_name = seq['input'].name + '[t+%d]'%k nw_name = seq['input'].name + '[t+%d]' % k
elif k == 0: elif k == 0:
nw_name = seq['input'].name + '[t]' nw_name = seq['input'].name + '[t]'
else: else:
nw_name = seq['input'].name + '[t%d]'%k nw_name = seq['input'].name + '[t%d]' % k
nw_slice.name = nw_name nw_slice.name = nw_name
# We cut the sequence such that seq[i] to correspond to # We cut the sequence such that seq[i] to correspond to
...@@ -490,34 +493,30 @@ def scan( fn ...@@ -490,34 +493,30 @@ def scan( fn
else: else:
offset = 0 offset = 0
if maxtap == mintap and maxtap != 0: if maxtap == mintap and maxtap != 0:
nw_seq =seq['input'][:abs(maxtap)] nw_seq = seq['input'][:abs(maxtap)]
elif maxtap -k != 0 : elif maxtap - k != 0:
nw_seq = seq['input'][offset +k -mintap: -(maxtap -k)] nw_seq = seq['input'][offset + k - mintap: -(maxtap - k)]
else: else:
nw_seq = seq['input'][offset +k -mintap: ] nw_seq = seq['input'][offset + k - mintap:]
if go_backwards: if go_backwards:
nw_seq = nw_seq[::-1] nw_seq = nw_seq[::-1]
scan_seqs.append(nw_seq)
inner_seqs.append(nw_slice)
scan_seqs.append( nw_seq ) inner_slices.append(actual_slice)
inner_seqs.append( nw_slice )
inner_slices.append( actual_slice )
n_seqs += 1 n_seqs += 1
# Since we've added all sequences now we need to level them up based on # Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes # n_steps or their different shapes
lengths_vec = [] lengths_vec = []
for seq in scan_seqs: for seq in scan_seqs:
lengths_vec.append( seq.shape[0] ) lengths_vec.append(seq.shape[0])
if not scan_utils.isNaN_or_Inf_or_None(n_steps): if not scan_utils.isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered # ^ N_steps should also be considered
lengths_vec.append( tensor.as_tensor(n_steps) ) lengths_vec.append(tensor.as_tensor(n_steps))
if len(lengths_vec) == 0:
if len(lengths_vec) == 0 :
# ^ No information about the number of steps # ^ No information about the number of steps
raise ValueError(' No information about the number of steps ' raise ValueError(' No information about the number of steps '
'provided. Either provide a value for ' 'provided. Either provide a value for '
...@@ -534,11 +533,12 @@ def scan( fn ...@@ -534,11 +533,12 @@ def scan( fn
actual_n_steps = tensor.as_tensor(n_steps) actual_n_steps = tensor.as_tensor(n_steps)
# Add names -- it helps a lot when debugging # Add names -- it helps a lot when debugging
for (nw_seq, seq) in zip(scan_seqs, seqs): for (nw_seq, seq) in zip(scan_seqs, seqs):
if getattr(seq['input'],'name', None) is not None: if getattr(seq['input'], 'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]'%k nw_seq.name = seq['input'].name + '[%d:]' % k
scan_seqs = [ seq[:actual_n_steps] for seq in scan_seqs] scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
# Conventions : # Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided # mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function ) # by the gradient function )
...@@ -546,7 +546,6 @@ def scan( fn ...@@ -546,7 +546,6 @@ def scan( fn
# sit_sot = single input tap, single output tap (t + 0) # sit_sot = single input tap, single output tap (t + 0)
# nit_sot = no input tap, single output tap (t + 0) # nit_sot = no input tap, single output tap (t + 0)
# MIT_MOT -- not provided by the user only by the grad function # MIT_MOT -- not provided by the user only by the grad function
n_mit_mot = 0 n_mit_mot = 0
n_mit_mot_outs = 0 n_mit_mot_outs = 0
...@@ -556,8 +555,6 @@ def scan( fn ...@@ -556,8 +555,6 @@ def scan( fn
mit_mot_out_slices = [] mit_mot_out_slices = []
mit_mot_rightOrder = [] mit_mot_rightOrder = []
# SIT_SOT -- provided by the user # SIT_SOT -- provided by the user
n_mit_sot = 0 n_mit_sot = 0
mit_sot_scan_inputs = [] mit_sot_scan_inputs = []
...@@ -576,9 +573,8 @@ def scan( fn ...@@ -576,9 +573,8 @@ def scan( fn
sit_sot_return_steps = {} sit_sot_return_steps = {}
sit_sot_rightOrder = [] sit_sot_rightOrder = []
# go through outputs picking up time slices as needed # go through outputs picking up time slices as needed
for i,init_out in enumerate(outs_info): for i, init_out in enumerate(outs_info):
# Note that our convention dictates that if an output uses # Note that our convention dictates that if an output uses
# just the previous time step, as a initial state we will only # just the previous time step, as a initial state we will only
# provide a tensor of the same dimension as one time step; This # provide a tensor of the same dimension as one time step; This
...@@ -602,11 +598,12 @@ def scan( fn ...@@ -602,11 +598,12 @@ def scan( fn
if config.compute_test_value != 'ignore': if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
# it will be done when fn will be called. # it will be done when fn will be called.
_logger.info(('Cannot compute test value for the inner ' _logger.info(('Cannot compute test value for the '
'function of scan, input value missing %s'), e) 'inner function of scan, input value missing %s'),
e)
if getattr(init_out['initial'],'name', None) is not None: if getattr(init_out['initial'], 'name', None) is not None:
arg.name = init_out['initial'].name+'[t-1]' arg.name = init_out['initial'].name + '[t-1]'
# We need now to allocate space for storing the output and copy # We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function # the initial state over. We do this using the expand function
...@@ -614,117 +611,119 @@ def scan( fn ...@@ -614,117 +611,119 @@ def scan( fn
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
scan_utils.expand( scan_utils.expand(
tensor.unbroadcast( tensor.unbroadcast(
tensor.shape_padleft(actual_arg), 0) tensor.shape_padleft(actual_arg), 0),
, actual_n_steps actual_n_steps
) ) ))
sit_sot_inner_slices.append(actual_arg) sit_sot_inner_slices.append(actual_arg)
if i in return_steps: if i in return_steps:
sit_sot_return_steps[n_sit_sot] = return_steps[i] sit_sot_return_steps[n_sit_sot] = return_steps[i]
sit_sot_inner_inputs.append( arg ) sit_sot_inner_inputs.append(arg)
sit_sot_rightOrder.append( i ) sit_sot_rightOrder.append(i)
n_sit_sot += 1 n_sit_sot += 1
elif init_out.get('taps',None): elif init_out.get('taps', None):
if numpy.any(numpy.array(init_out.get('taps',[])) > 0): if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a # Make sure we do not have requests for future values of a
# sequence we can not provide such values # sequence we can not provide such values
raise ValueError('Can not use future taps of outputs' raise ValueError('Can not use future taps of outputs',
, init_out) init_out)
# go through the taps # go through the taps
mintap = abs(numpy.min(init_out['taps'])) mintap = abs(numpy.min(init_out['taps']))
mit_sot_tap_array.append( init_out['taps'] ) mit_sot_tap_array.append(init_out['taps'])
idx_offset = abs(numpy.min(init_out['taps'])) idx_offset = abs(numpy.min(init_out['taps']))
# Sequence # Sequence
mit_sot_scan_inputs.append( mit_sot_scan_inputs.append(
scan_utils.expand( init_out['initial'][:mintap] scan_utils.expand(init_out['initial'][:mintap],
, actual_n_steps) ) actual_n_steps))
if i in return_steps: if i in return_steps:
mit_sot_return_steps[n_mit_sot] = return_steps[i] mit_sot_return_steps[n_mit_sot] = return_steps[i]
mit_sot_rightOrder.append( i ) mit_sot_rightOrder.append(i)
n_mit_sot += 1 n_mit_sot += 1
for k in init_out['taps']: for k in init_out['taps']:
# create a new slice # create a new slice
actual_nw_slice = init_out['initial'][k+mintap] actual_nw_slice = init_out['initial'][k + mintap]
_init_out_var = tensor.as_tensor_variable(init_out['initial']) _init_out_var = tensor.as_tensor_variable(init_out['initial'])
_init_out_var_slice = _init_out_var[k+mintap] _init_out_var_slice = _init_out_var[k + mintap]
nw_slice = _init_out_var_slice.type() nw_slice = _init_out_var_slice.type()
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != 'off': if config.compute_test_value != 'off':
try: try:
nw_slice.tag.test_value = gof.Op._get_test_value(_init_out_var_slice) nw_slice.tag.test_value = gof.Op._get_test_value(
_init_out_var_slice)
except AttributeError, e: except AttributeError, e:
if config.compute_test_value != 'ignore': if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
# it will be done when fn will be called. # it will be done when fn will be called.
_logger.info(('Cannot compute test value for the inner ' _logger.info(('Cannot compute test value for '
'function of scan, input value missing. %s'), e) 'the inner function of scan, input value '
'missing. %s'), e)
# give it a name or debugging and pretty printing # give it a name or debugging and pretty printing
if getattr(init_out['initial'],'name', None) is not None: if getattr(init_out['initial'], 'name', None) is not None:
if k > 0: if k > 0:
nw_slice.name = ( init_out['initial'].name + nw_slice.name = (init_out['initial'].name +
'[t+%d]'%k ) '[t+%d]' % k)
elif k == 0: elif k == 0:
nw_slice.name = init_out['initial'].name + '[t]' nw_slice.name = init_out['initial'].name + '[t]'
else: else:
nw_slice.name = ( init_out['initial'].name + nw_slice.name = (init_out['initial'].name +
'[t%d]'%k ) '[t%d]' % k)
mit_sot_inner_inputs.append( nw_slice ) mit_sot_inner_inputs.append(nw_slice)
mit_sot_inner_slices.append( actual_nw_slice ) mit_sot_inner_slices.append(actual_nw_slice)
#NOTE: there is another case, in which we do not want to provide #NOTE: there is another case, in which we do not want to provide
# any previous value of the output to the inner function (i.e. # any previous value of the output to the inner function (i.e.
# a map); in that case we do not have to do anything .. # a map); in that case we do not have to do anything ..
# Re-order args # Re-order args
max_mit_sot = numpy.max( [-1] + mit_sot_rightOrder ) + 1 max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
max_sit_sot = numpy.max( [-1] + sit_sot_rightOrder ) + 1 max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
n_elems = numpy.max( [ max_mit_sot, max_sit_sot ] ) n_elems = numpy.max([max_mit_sot, max_sit_sot])
_ordered_args = [[] for x in xrange(n_elems)] _ordered_args = [[] for x in xrange(n_elems)]
offset = 0 offset = 0
for idx in xrange(n_mit_sot): for idx in xrange(n_mit_sot):
n_inputs = len(mit_sot_tap_array[idx]) n_inputs = len(mit_sot_tap_array[idx])
if n_fixed_steps in [1,-1]: if n_fixed_steps in [1, -1]:
_ordered_args[mit_sot_rightOrder[idx]] = \ _ordered_args[mit_sot_rightOrder[idx]] = \
mit_sot_inner_slices[offset:offset+n_inputs] mit_sot_inner_slices[offset:offset + n_inputs]
else: else:
_ordered_args[mit_sot_rightOrder[idx]] = \ _ordered_args[mit_sot_rightOrder[idx]] = \
mit_sot_inner_inputs[offset:offset+n_inputs] mit_sot_inner_inputs[offset:offset + n_inputs]
offset += n_inputs offset += n_inputs
for idx in xrange(n_sit_sot): for idx in xrange(n_sit_sot):
if n_fixed_steps in [1,-1]: if n_fixed_steps in [1, -1]:
_ordered_args[sit_sot_rightOrder[idx]] = \ _ordered_args[sit_sot_rightOrder[idx]] = \
[ sit_sot_inner_slices[idx] ] [sit_sot_inner_slices[idx]]
else: else:
_ordered_args[sit_sot_rightOrder[idx]] = \ _ordered_args[sit_sot_rightOrder[idx]] = \
[ sit_sot_inner_inputs[idx] ] [sit_sot_inner_inputs[idx]]
ordered_args = [] ordered_args = []
for ls in _ordered_args: for ls in _ordered_args:
ordered_args += ls ordered_args += ls
if n_fixed_steps in [1,-1]: if n_fixed_steps in [1, -1]:
args = (inner_slices + args = (inner_slices +
ordered_args + ordered_args +
non_seqs ) non_seqs)
else: else:
args = ( inner_seqs + args = (inner_seqs +
ordered_args + ordered_args +
non_seqs ) non_seqs)
# add only the non-shared variables and non-constants to the arguments of the dummy # add only the non-shared variables and non-constants to the arguments of
# function [ a function should not get shared variables or constants as input ] # the dummy function [ a function should not get shared variables or
# constants as input ]
dummy_args = [arg for arg in args dummy_args = [arg for arg in args
if (not isinstance(arg, SharedVariable) and if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant) )] not isinstance(arg, tensor.Constant))]
# when we apply the lambda expression we get a mixture of update rules # when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated # and outputs that needs to be separated
condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args)) condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
if condition is not None: if condition is not None:
as_while = True as_while = True
...@@ -734,14 +733,13 @@ def scan( fn ...@@ -734,14 +733,13 @@ def scan( fn
### Step 3. Check if we actually need scan and remove it if we don't ### Step 3. Check if we actually need scan and remove it if we don't
## ##
if n_fixed_steps in [1, -1]: if n_fixed_steps in [1, -1]:
# We do not need to use the scan op anymore, so we can just return # We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have # the outputs and updates we have
if condition is not None: if condition is not None:
_logger.warning( ('When the number of steps is fixed and equal to 1,' _logger.warning(('When the number of steps is fixed and equal '
' the provided stopping condition, ', str(condition), 'to 1, the provided stopping condition, ',
' is ignored')) str(condition), ' is ignored'))
for pos, inner_out in enumerate(outputs): for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an # we need to see if we need to pad our sequences with an
...@@ -750,16 +748,15 @@ def scan( fn ...@@ -750,16 +748,15 @@ def scan( fn
# then, if we return the output as given by the innner function # then, if we return the output as given by the innner function
# this will represent only a slice and it will have one # this will represent only a slice and it will have one
# dimension less. # dimension less.
if ( isinstance(inner_out.type, tensor.TensorType) and if (isinstance(inner_out.type, tensor.TensorType) and
return_steps.get(pos, 0) != 1): return_steps.get(pos, 0) != 1):
outputs[pos] = tensor.unbroadcast( outputs[pos] = tensor.unbroadcast(
tensor.shape_padleft(inner_out),0) tensor.shape_padleft(inner_out), 0)
if len(outputs) == 1: if len(outputs) == 1:
outputs = outputs[0] outputs = outputs[0]
return (outputs, updates) return (outputs, updates)
## ##
### Step 4. Compile the dummy function ### Step 4. Compile the dummy function
## ##
...@@ -779,11 +776,11 @@ def scan( fn ...@@ -779,11 +776,11 @@ def scan( fn
replace=dict(zip(non_seqs, replace=dict(zip(non_seqs,
fake_nonseqs))) fake_nonseqs)))
all_inputs = itertools.ifilter( all_inputs = itertools.ifilter(
lambda x: ( isinstance(x, gof.Variable) and lambda x: (isinstance(x, gof.Variable) and
not isinstance(x, SharedVariable) and not isinstance(x, SharedVariable) and
not isinstance(x, gof.Constant) ), not isinstance(x, gof.Constant)),
gof.graph.inputs( fake_outputs) ) gof.graph.inputs(fake_outputs))
extra_inputs = filter( lambda x: x not in args + fake_nonseqs, extra_inputs = filter(lambda x: x not in args + fake_nonseqs,
all_inputs) all_inputs)
non_seqs += extra_inputs non_seqs += extra_inputs
## Note we do not use all_inputs directly since the order of variables ## Note we do not use all_inputs directly since the order of variables
...@@ -793,12 +790,11 @@ def scan( fn ...@@ -793,12 +790,11 @@ def scan( fn
dummy_outs = outputs dummy_outs = outputs
if condition is not None: if condition is not None:
dummy_outs.append(condition) dummy_outs.append(condition)
dummy_f = function( dummy_args dummy_f = function(dummy_args,
, dummy_outs dummy_outs,
, updates = updates updates=updates,
, mode = compile.mode.Mode(linker='py', mode=compile.mode.Mode(linker='py',
optimizer=None) ) optimizer=None))
## ##
### Step 5. Re-arange inputs of scan into a more strict order ### Step 5. Re-arange inputs of scan into a more strict order
...@@ -807,7 +803,6 @@ def scan( fn ...@@ -807,7 +803,6 @@ def scan( fn
## Step 5.0 Check the outputs of the dummy function to see if they ## Step 5.0 Check the outputs of the dummy function to see if they
## match with user provided data ## match with user provided data
# if the number of outputs to the function does not match the number of # if the number of outputs to the function does not match the number of
# assumed outputs until now (provided by the user) there can be # assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the # only one explanation: No information is provided for any of the
...@@ -815,7 +810,7 @@ def scan( fn ...@@ -815,7 +810,7 @@ def scan( fn
tmp_dummy_f_outs = len(dummy_f.maker.outputs) tmp_dummy_f_outs = len(dummy_f.maker.outputs)
if as_while: if as_while:
tmp_dummy_f_outs -= 1 tmp_dummy_f_outs -= 1
if not ( tmp_dummy_f_outs == n_outs or outs_info == []): if not (tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for ' raise ValueError('Please provide None as output_info for '
'any output that does not feed back into ' 'any output that does not feed back into '
'scan (i.e. it behaves like a map) ') 'scan (i.e. it behaves like a map) ')
...@@ -824,21 +819,18 @@ def scan( fn ...@@ -824,21 +819,18 @@ def scan( fn
n_outs = len(dummy_f.maker.outputs) n_outs = len(dummy_f.maker.outputs)
if as_while: if as_while:
n_outs = n_outs - 1 n_outs = n_outs - 1
outs_info = [ dict() for x in xrange(n_outs) ] outs_info = [dict() for x in xrange(n_outs)]
## Step 5.1 Outputs with taps different then -1 ## Step 5.1 Outputs with taps different then -1
for i, out in enumerate(outs_info): for i, out in enumerate(outs_info):
if 'taps' in out and out['taps'] != [-1]: if 'taps' in out and out['taps'] != [-1]:
mit_sot_inner_outputs.append( outputs[i]) mit_sot_inner_outputs.append(outputs[i])
## Step 5.2 Outputs with tap equal to -1 ## Step 5.2 Outputs with tap equal to -1
for i, out in enumerate(outs_info): for i, out in enumerate(outs_info):
if 'taps' in out and out['taps'] == [-1]: if 'taps' in out and out['taps'] == [-1]:
sit_sot_inner_outputs.append( outputs[i] ) sit_sot_inner_outputs.append(outputs[i])
## Step 5.3 Outputs that correspond to update rules of shared variables ## Step 5.3 Outputs that correspond to update rules of shared variables
givens = {} givens = {}
...@@ -849,11 +841,11 @@ def scan( fn ...@@ -849,11 +841,11 @@ def scan( fn
for input in dummy_f.maker.expanded_inputs: for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update: if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if getattr(input.variable,'name', None) is not None: if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy' new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append( new_var ) shared_inner_inputs.append(new_var)
shared_scan_inputs.append( input.variable ) shared_scan_inputs.append(input.variable)
shared_inner_outputs.append( input.update ) shared_inner_outputs.append(input.update)
givens[input.variable] = new_var givens[input.variable] = new_var
n_shared_outs += 1 n_shared_outs += 1
...@@ -862,57 +854,56 @@ def scan( fn ...@@ -862,57 +854,56 @@ def scan( fn
nit_sot_inner_outputs = [] nit_sot_inner_outputs = []
nit_sot_return_steps = {} nit_sot_return_steps = {}
nit_sot_rightOrder = [] nit_sot_rightOrder = []
for i,out in enumerate(outs_info): for i, out in enumerate(outs_info):
if not 'taps' in out: if not 'taps' in out:
nit_sot_inner_outputs.append( outputs[i] ) nit_sot_inner_outputs.append(outputs[i])
if i in return_steps: if i in return_steps:
nit_sot_return_steps[n_nit_sot] = return_steps[i] nit_sot_return_steps[n_nit_sot] = return_steps[i]
nit_sot_rightOrder.append( i ) nit_sot_rightOrder.append(i)
n_nit_sot += 1 n_nit_sot += 1
## Step 5.5 all other arguments including extra inputs ## Step 5.5 all other arguments including extra inputs
other_scan_args = [] other_scan_args = []
other_inner_args = [] other_inner_args = []
other_scan_args += [ arg for arg in non_seqs other_scan_args += [arg for arg in non_seqs
if (not isinstance(arg, SharedVariable) and if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))] not isinstance(arg, tensor.Constant))]
## Step 5.6 all shared variables with no update rules ## Step 5.6 all shared variables with no update rules
other_inner_args += [ safe_new(arg,'_copy') for arg in non_seqs other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs
if (not isinstance(arg, SharedVariable) and if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))] not isinstance(arg, tensor.Constant))]
givens.update( dict( zip(other_scan_args, other_inner_args) )) givens.update(dict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [ arg.variable for arg other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs in dummy_f.maker.expanded_inputs
if ( isinstance(arg.variable, SharedVariable) and if (isinstance(arg.variable, SharedVariable) and
not arg.update) ] not arg.update)]
other_shared_inner_args = [ safe_new(arg.variable, '_copy') for arg other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs in dummy_f.maker.expanded_inputs
if ( isinstance(arg.variable, SharedVariable) and if (isinstance(arg.variable, SharedVariable) and
not arg.update) ] not arg.update)]
givens.update( dict( zip( other_shared_scan_args, givens.update(dict(zip(other_shared_scan_args,
other_shared_inner_args) ) ) other_shared_inner_args)))
## ##
### Step 6. Re-order the outputs and clone them replacing things ### Step 6. Re-order the outputs and clone them replacing things
### using the givens ### using the givens
## ##
inner_inputs = ( inner_seqs + inner_inputs = (inner_seqs +
mit_mot_inner_inputs + mit_mot_inner_inputs +
mit_sot_inner_inputs + mit_sot_inner_inputs +
sit_sot_inner_inputs + sit_sot_inner_inputs +
shared_inner_inputs + shared_inner_inputs +
other_shared_inner_args + other_shared_inner_args +
other_inner_args ) other_inner_args)
inner_outs = ( mit_mot_inner_outputs + inner_outs = (mit_mot_inner_outputs +
mit_sot_inner_outputs + mit_sot_inner_outputs +
sit_sot_inner_outputs + sit_sot_inner_outputs +
nit_sot_inner_outputs + nit_sot_inner_outputs +
shared_inner_outputs ) shared_inner_outputs)
if condition is not None: if condition is not None:
inner_outs.append(condition) inner_outs.append(condition)
# Cuda is imported here, instead of being imported on top of the file # Cuda is imported here, instead of being imported on top of the file
...@@ -927,18 +918,17 @@ def scan( fn ...@@ -927,18 +918,17 @@ def scan( fn
# variables are put on GPU right aways >:| , # variables are put on GPU right aways >:| ,
new_givens = {} new_givens = {}
for w, w_copy in givens.iteritems():
for w,w_copy in givens.iteritems():
if (isinstance(w.type, cuda.CudaNdarrayType) if (isinstance(w.type, cuda.CudaNdarrayType)
and isinstance(w_copy.type, tensor.TensorType)): and isinstance(w_copy.type, tensor.TensorType)):
for o in inner_outs: for o in inner_outs:
new_givens = traverse(o,w,w_copy, new_givens) new_givens = traverse(o, w, w_copy, new_givens)
else: else:
new_givens[w] = w_copy new_givens[w] = w_copy
else: else:
new_givens = givens new_givens = givens
new_outs = scan_utils.clone(inner_outs, replace = new_givens) new_outs = scan_utils.clone(inner_outs, replace=new_givens)
## ##
### Step 7. Create the Scan Op ### Step 7. Create the Scan Op
...@@ -964,22 +954,22 @@ def scan( fn ...@@ -964,22 +954,22 @@ def scan( fn
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = profile info['profile'] = profile
local_op = scan_op.Scan( inner_inputs, new_outs, info ) local_op = scan_op.Scan(inner_inputs, new_outs, info)
## ##
### Step 8. Compute the outputs using the scan op ### Step 8. Compute the outputs using the scan op
## ##
_scan_inputs = ( scan_seqs + _scan_inputs = (scan_seqs +
mit_mot_scan_inputs + mit_mot_scan_inputs +
mit_sot_scan_inputs + mit_sot_scan_inputs +
sit_sot_scan_inputs + sit_sot_scan_inputs +
shared_scan_inputs + shared_scan_inputs +
[ actual_n_steps for x in xrange(n_nit_sot) ] + [actual_n_steps for x in xrange(n_nit_sot)] +
other_shared_scan_args + other_shared_scan_args +
other_scan_args ) other_scan_args)
scan_inputs = [] scan_inputs = []
for arg in [actual_n_steps]+ _scan_inputs: for arg in [actual_n_steps] + _scan_inputs:
try: try:
arg = tensor.as_tensor_variable(arg) arg = tensor.as_tensor_variable(arg)
except TypeError: except TypeError:
...@@ -987,8 +977,8 @@ def scan( fn ...@@ -987,8 +977,8 @@ def scan( fn
# to make sure no input is a cuda ndarrays # to make sure no input is a cuda ndarrays
pass pass
scan_inputs += [arg] scan_inputs += [arg]
scan_outs = local_op(* scan_inputs ) scan_outs = local_op(*scan_inputs)
if type(scan_outs) not in (list,tuple): if type(scan_outs) not in (list, tuple):
scan_outs = [scan_outs] scan_outs = [scan_outs]
## ##
### Step 9. Figure out which outs are update rules for shared variables ### Step 9. Figure out which outs are update rules for shared variables
...@@ -996,54 +986,56 @@ def scan( fn ...@@ -996,54 +986,56 @@ def scan( fn
## ##
update_map = Updates() update_map = Updates()
def remove_dimensions(outs, steps_return, offsets=None):
    """Trim the leading (time) dimension of scan outputs as requested.

    For each output, one of three things happens:

    * if ``idx`` has an entry in ``steps_return`` asking for more than one
      step, keep only the last ``steps_return[idx]`` entries along the
      leading dimension;
    * if it asks for exactly one step, return just the last entry (the
      leading dimension is dropped entirely);
    * otherwise return the output as-is, or, when ``offsets`` is given,
      with the first ``offsets[idx]`` entries (the initial-state padding)
      sliced off.

    Parameters
    ----------
    outs : sequence
        Scan output variables; each is indexable/sliceable along its
        leading dimension (presumably symbolic tensors here — the logic
        only relies on ``__getitem__``).
    steps_return : mapping
        Maps an output position to the number of trailing steps the
        caller asked for.  Positions absent from the mapping fall through
        to the ``offsets`` handling.
    offsets : sequence or None, optional
        Per-output count of leading entries to discard (e.g. the number
        of initial taps).  ``None`` means no trimming.

    Returns
    -------
    list
        The trimmed outputs, in the same order as ``outs``.
    """
    out_ls = []
    for idx, out in enumerate(outs):
        if idx in steps_return:
            n_steps = steps_return[idx]
            if n_steps > 1:
                # Keep the last n_steps entries; leading dim preserved.
                out_ls.append(out[-n_steps:])
            else:
                # Exactly one step requested: drop the leading dimension.
                out_ls.append(out[-1])
        elif offsets is None:
            out_ls.append(out)
        else:
            # Slice away the initial-state padding for this output.
            out_ls.append(out[offsets[idx]:])
    return out_ls
offset = n_mit_mot offset = n_mit_mot
offsets = [ abs(numpy.min(x)) for x in mit_sot_tap_array ] offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
mit_sot_outs = remove_dimensions( mit_sot_outs = remove_dimensions(
scan_outs[offset:offset+n_mit_sot] scan_outs[offset:offset + n_mit_sot],
, mit_sot_return_steps mit_sot_return_steps,
, offsets ) offsets)
offset += n_mit_sot offset += n_mit_sot
offsets = [ 1 for x in xrange(n_sit_sot) ] offsets = [1 for x in xrange(n_sit_sot)]
sit_sot_outs = remove_dimensions( sit_sot_outs = remove_dimensions(
scan_outs[offset:offset+n_sit_sot] scan_outs[offset:offset + n_sit_sot],
, sit_sot_return_steps sit_sot_return_steps,
, offsets ) offsets)
offset += n_sit_sot offset += n_sit_sot
nit_sot_outs = remove_dimensions( nit_sot_outs = remove_dimensions(
scan_outs[offset:offset+n_nit_sot] scan_outs[offset:offset + n_nit_sot],
, nit_sot_return_steps ) nit_sot_return_steps)
offset += n_nit_sot offset += n_nit_sot
for idx, update_rule in enumerate(scan_outs[offset:offset+n_shared_outs]): for idx, update_rule in enumerate(
scan_outs[offset:offset + n_shared_outs]):
update_map[shared_scan_inputs[idx]] = update_rule update_map[shared_scan_inputs[idx]] = update_rule
_scan_out_list = ( mit_sot_outs + _scan_out_list = (mit_sot_outs +
sit_sot_outs + sit_sot_outs +
nit_sot_outs ) nit_sot_outs)
# Step 10. I need to reorder the outputs to be in the order expected by # Step 10. I need to reorder the outputs to be in the order expected by
# the user # the user
rightOrder = ( mit_sot_rightOrder + rightOrder = (mit_sot_rightOrder +
sit_sot_rightOrder + sit_sot_rightOrder +
nit_sot_rightOrder ) nit_sot_rightOrder)
scan_out_list = [None]*len(rightOrder) scan_out_list = [None] * len(rightOrder)
for idx,pos in enumerate(rightOrder): for idx, pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx] scan_out_list[pos] = _scan_out_list[idx]
if len(scan_out_list) == 1: if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0] scan_out_list = scan_out_list[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论