Commit f4b29ab6, authored by nouiz

Merge pull request #166 from pascanur/scan_grad_nsteps

Fix a bug in grad reported by Michael Forbes
......@@ -34,10 +34,10 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
``foldr()``.
"""
__docformat__ = 'restructedtext en'
__authors__ = ( "Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin " )
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
......@@ -63,16 +63,16 @@ from scan_utils import safe_new, traverse
_logger = logging.getLogger('theano.scan_module.scan')
def scan( fn
, sequences = None
, outputs_info = None
, non_sequences = None
, n_steps = None
, truncate_gradient = -1
, go_backwards = False
, mode = None
, name = None
, profile = False):
def scan(fn,
sequences=None,
outputs_info=None,
non_sequences=None,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
mode=None,
name=None,
profile=False):
"""
This function constructs and applies a Scan op to the provided
arguments.
......@@ -333,12 +333,12 @@ def scan( fn
'''
if x is None:
return []
elif not isinstance(x, (list,tuple)):
elif not isinstance(x, (list, tuple)):
return [x]
else:
return list(x)
seqs = wrap_into_list(sequences)
seqs = wrap_into_list(sequences)
outs_info = wrap_into_list(outputs_info)
# Make sure we get rid of numpy arrays or ints or anything like that
......@@ -356,19 +356,19 @@ def scan( fn
# To do that we check here to see the nature of n_steps
n_fixed_steps = None
if isinstance( n_steps, (float,int)):
if isinstance(n_steps, (float, int)):
n_fixed_steps = int(n_steps)
else:
try :
try:
n_fixed_steps = opt.get_constant_value(n_steps)
except (TypeError, AttributeError):
n_fixed_steps = None
# Check n_steps is an int
if ( hasattr(n_steps,'dtype') and
str(n_steps.dtype)[:3] not in ('uin','int') ):
if (hasattr(n_steps, 'dtype') and
str(n_steps.dtype)[:3] not in ('uin', 'int')):
raise ValueError(' n_steps must be an int. dtype provided '
'is %s'%n_steps.dtype)
'is %s' % n_steps.dtype)
# compute number of sequences and number of outputs
n_seqs = len(seqs)
......@@ -377,11 +377,11 @@ def scan( fn
return_steps = {}
# wrap sequences in a dictionary if they are not already dictionaries
for i in xrange(n_seqs):
if not isinstance(seqs[i], dict) :
if not isinstance(seqs[i], dict):
seqs[i] = dict(input=seqs[i], taps=[0])
elif seqs[i].get('taps',None):
elif seqs[i].get('taps', None):
seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
elif seqs[i].get('taps',True) is None:
elif seqs[i].get('taps', True) is None:
# seqs dictionary does not have the ``taps`` key
seqs[i]['taps'] = [0]
......@@ -391,30 +391,31 @@ def scan( fn
if isinstance(outs_info[i], dict):
# DEPRICATED :
if outs_info[i].get('return_steps', None):
_logger.warning( ("Using `return_steps` has been depricated."
" Simply select the entries you need using "
" a subtensor. Scan will optimize memory "
" consumption, so do not worry about that."))
_logger.warning(("Using `return_steps` has been "
"depricated. Simply select the entries you "
"need using a subtensor. Scan will optimize "
"memory consumption, so do not worry about "
"that."))
return_steps[i] = outs_info[i]['return_steps']
# END
if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1
outs_info[i] = dict(initial=outs_info[i], taps = [-1])
elif (not outs_info[i].get('initial',None) and
outs_info[i].get('taps',None)):
outs_info[i] = dict(initial=outs_info[i], taps=[-1])
elif (not outs_info[i].get('initial', None) and
outs_info[i].get('taps', None)):
# ^ no initial state but taps provided
raise ValueError( ( 'If you are using slices of an output '
'you need to provide a initial state '
'for it'), outs_info[i] )
elif (outs_info[i].get('initial',None) and
not outs_info[i].get('taps',None)):
raise ValueError(('If you are using slices of an output '
'you need to provide a initial state '
'for it'), outs_info[i])
elif (outs_info[i].get('initial', None) and
not outs_info[i].get('taps', None)):
# ^ initial state but taps not provided
if outs_info[i].has_key('taps'):
if 'taps' in outs_info[i]:
# ^ explicitly provided a None for taps
_logger.warning('Output %s ( index %d) has a initial state '
'but taps is explicitly set to None ',
getattr(outs_info[i]['initial'],'name','None'),
_logger.warning('Output %s ( index %d) has a initial '
'state but taps is explicitly set to None ',
getattr(outs_info[i]['initial'], 'name', 'None'),
i)
outs_info[i]['taps'] = [-1]
else:
......@@ -434,12 +435,12 @@ def scan( fn
# and to construct a new and complete list of inputs and
# outputs
n_seqs = 0
scan_seqs = [] # Variables passed as inputs to the scan op
inner_seqs = [] # Variables passed as inputs to the inner function
inner_slices = [] # Actual slices if scan is removed from the picture
n_seqs = 0
scan_seqs = [] # Variables passed as inputs to the scan op
inner_seqs = [] # Variables passed as inputs to the inner function
inner_slices = [] # Actual slices if scan is removed from the picture
# go through sequences picking up time slices as needed
for i,seq in enumerate(seqs):
for i, seq in enumerate(seqs):
# Note that you can have something like no taps for
# a sequence, though is highly unlikely in practice
if 'taps' in seq:
......@@ -456,31 +457,33 @@ def scan( fn
# If not we need to use copies, that will be replaced at
# each frame by the corresponding slice
actual_slice = seq['input'][k-mintap]
actual_slice = seq['input'][k - mintap]
_seq_val = tensor.as_tensor_variable(seq['input'])
_seq_val_slice = _seq_val[k-mintap]
_seq_val_slice = _seq_val[k - mintap]
nw_slice = _seq_val_slice.type()
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice)
nw_slice.tag.test_value = gof.Op._get_test_value(
_seq_val_slice)
except AttributeError, e:
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(('Cannot compute test value for the inner '
'function of scan, input value missing %s'), e)
_logger.info(('Cannot compute test value for '
'the inner function of scan, input value '
'missing %s'), e)
# Add names to slices for debugging and pretty printing ..
# that is if the input already has a name
if getattr(seq['input'],'name', None) is not None:
if getattr(seq['input'], 'name', None) is not None:
if k > 0:
nw_name = seq['input'].name + '[t+%d]'%k
nw_name = seq['input'].name + '[t+%d]' % k
elif k == 0:
nw_name = seq['input'].name + '[t]'
else:
nw_name = seq['input'].name + '[t%d]'%k
nw_name = seq['input'].name + '[t%d]' % k
nw_slice.name = nw_name
# We cut the sequence such that seq[i] to correspond to
......@@ -490,34 +493,30 @@ def scan( fn
else:
offset = 0
if maxtap == mintap and maxtap != 0:
nw_seq =seq['input'][:abs(maxtap)]
elif maxtap -k != 0 :
nw_seq = seq['input'][offset +k -mintap: -(maxtap -k)]
nw_seq = seq['input'][:abs(maxtap)]
elif maxtap - k != 0:
nw_seq = seq['input'][offset + k - mintap: -(maxtap - k)]
else:
nw_seq = seq['input'][offset +k -mintap: ]
nw_seq = seq['input'][offset + k - mintap:]
if go_backwards:
nw_seq = nw_seq[::-1]
scan_seqs.append( nw_seq )
inner_seqs.append( nw_slice )
inner_slices.append( actual_slice )
scan_seqs.append(nw_seq)
inner_seqs.append(nw_slice)
inner_slices.append(actual_slice)
n_seqs += 1
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
lengths_vec = []
for seq in scan_seqs:
lengths_vec.append( seq.shape[0] )
lengths_vec.append(seq.shape[0])
if not scan_utils.isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered
lengths_vec.append( tensor.as_tensor(n_steps) )
lengths_vec.append(tensor.as_tensor(n_steps))
if len(lengths_vec) == 0 :
if len(lengths_vec) == 0:
# ^ No information about the number of steps
raise ValueError(' No information about the number of steps '
'provided. Either provide a value for '
......@@ -534,10 +533,12 @@ def scan( fn
actual_n_steps = tensor.as_tensor(n_steps)
# Add names -- it helps a lot when debugging
for (nw_seq, seq) in zip(scan_seqs, seqs):
if getattr(seq['input'],'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]'%k
if getattr(seq['input'], 'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]' % k
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
# Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function )
......@@ -545,39 +546,35 @@ def scan( fn
# sit_sot = single input tap, single output tap (t + 0)
# nit_sot = no input tap, single output tap (t + 0)
# MIT_MOT -- not provided by the user only by the grad function
n_mit_mot = 0
n_mit_mot_outs = 0
mit_mot_scan_inputs = []
mit_mot_inner_inputs = []
n_mit_mot = 0
n_mit_mot_outs = 0
mit_mot_scan_inputs = []
mit_mot_inner_inputs = []
mit_mot_inner_outputs = []
mit_mot_out_slices = []
mit_mot_rightOrder = []
mit_mot_out_slices = []
mit_mot_rightOrder = []
# SIT_SOT -- provided by the user
n_mit_sot = 0
mit_sot_scan_inputs = []
mit_sot_inner_inputs = []
mit_sot_inner_slices = []
n_mit_sot = 0
mit_sot_scan_inputs = []
mit_sot_inner_inputs = []
mit_sot_inner_slices = []
mit_sot_inner_outputs = []
mit_sot_return_steps = {}
mit_sot_tap_array = []
mit_sot_rightOrder = []
n_sit_sot = 0
sit_sot_scan_inputs = []
sit_sot_inner_inputs = []
sit_sot_inner_slices = []
mit_sot_return_steps = {}
mit_sot_tap_array = []
mit_sot_rightOrder = []
n_sit_sot = 0
sit_sot_scan_inputs = []
sit_sot_inner_inputs = []
sit_sot_inner_slices = []
sit_sot_inner_outputs = []
sit_sot_return_steps = {}
sit_sot_rightOrder = []
sit_sot_return_steps = {}
sit_sot_rightOrder = []
# go through outputs picking up time slices as needed
for i,init_out in enumerate(outs_info):
for i, init_out in enumerate(outs_info):
# Note that our convention dictates that if an output uses
# just the previous time step, as a initial state we will only
# provide a tensor of the same dimension as one time step; This
......@@ -601,11 +598,12 @@ def scan( fn
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(('Cannot compute test value for the inner '
'function of scan, input value missing %s'), e)
_logger.info(('Cannot compute test value for the '
'inner function of scan, input value missing %s'),
e)
if getattr(init_out['initial'],'name', None) is not None:
arg.name = init_out['initial'].name+'[t-1]'
if getattr(init_out['initial'], 'name', None) is not None:
arg.name = init_out['initial'].name + '[t-1]'
# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
......@@ -613,117 +611,119 @@ def scan( fn
sit_sot_scan_inputs.append(
scan_utils.expand(
tensor.unbroadcast(
tensor.shape_padleft(actual_arg), 0)
, actual_n_steps
) )
tensor.shape_padleft(actual_arg), 0),
actual_n_steps
))
sit_sot_inner_slices.append(actual_arg)
if i in return_steps:
sit_sot_return_steps[n_sit_sot] = return_steps[i]
sit_sot_inner_inputs.append( arg )
sit_sot_rightOrder.append( i )
sit_sot_inner_inputs.append(arg)
sit_sot_rightOrder.append(i)
n_sit_sot += 1
elif init_out.get('taps',None):
elif init_out.get('taps', None):
if numpy.any(numpy.array(init_out.get('taps',[])) > 0):
if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a
# sequence we can not provide such values
raise ValueError('Can not use future taps of outputs'
, init_out)
raise ValueError('Can not use future taps of outputs',
init_out)
# go through the taps
mintap = abs(numpy.min(init_out['taps']))
mit_sot_tap_array.append( init_out['taps'] )
mit_sot_tap_array.append(init_out['taps'])
idx_offset = abs(numpy.min(init_out['taps']))
# Sequence
mit_sot_scan_inputs.append(
scan_utils.expand( init_out['initial'][:mintap]
, actual_n_steps) )
scan_utils.expand(init_out['initial'][:mintap],
actual_n_steps))
if i in return_steps:
mit_sot_return_steps[n_mit_sot] = return_steps[i]
mit_sot_rightOrder.append( i )
mit_sot_rightOrder.append(i)
n_mit_sot += 1
for k in init_out['taps']:
# create a new slice
actual_nw_slice = init_out['initial'][k+mintap]
actual_nw_slice = init_out['initial'][k + mintap]
_init_out_var = tensor.as_tensor_variable(init_out['initial'])
_init_out_var_slice = _init_out_var[k+mintap]
_init_out_var_slice = _init_out_var[k + mintap]
nw_slice = _init_out_var_slice.type()
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
nw_slice.tag.test_value = gof.Op._get_test_value(_init_out_var_slice)
nw_slice.tag.test_value = gof.Op._get_test_value(
_init_out_var_slice)
except AttributeError, e:
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(('Cannot compute test value for the inner '
'function of scan, input value missing. %s'), e)
_logger.info(('Cannot compute test value for '
'the inner function of scan, input value '
'missing. %s'), e)
# give it a name or debugging and pretty printing
if getattr(init_out['initial'],'name', None) is not None:
if getattr(init_out['initial'], 'name', None) is not None:
if k > 0:
nw_slice.name = ( init_out['initial'].name +
'[t+%d]'%k )
nw_slice.name = (init_out['initial'].name +
'[t+%d]' % k)
elif k == 0:
nw_slice.name = init_out['initial'].name + '[t]'
else:
nw_slice.name = ( init_out['initial'].name +
'[t%d]'%k )
mit_sot_inner_inputs.append( nw_slice )
mit_sot_inner_slices.append( actual_nw_slice )
nw_slice.name = (init_out['initial'].name +
'[t%d]' % k)
mit_sot_inner_inputs.append(nw_slice)
mit_sot_inner_slices.append(actual_nw_slice)
#NOTE: there is another case, in which we do not want to provide
# any previous value of the output to the inner function (i.e.
# a map); in that case we do not have to do anything ..
# Re-order args
max_mit_sot = numpy.max( [-1] + mit_sot_rightOrder ) + 1
max_sit_sot = numpy.max( [-1] + sit_sot_rightOrder ) + 1
n_elems = numpy.max( [ max_mit_sot, max_sit_sot ] )
max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
n_elems = numpy.max([max_mit_sot, max_sit_sot])
_ordered_args = [[] for x in xrange(n_elems)]
offset = 0
for idx in xrange(n_mit_sot):
n_inputs = len(mit_sot_tap_array[idx])
if n_fixed_steps in [1,-1]:
if n_fixed_steps in [1, -1]:
_ordered_args[mit_sot_rightOrder[idx]] = \
mit_sot_inner_slices[offset:offset+n_inputs]
mit_sot_inner_slices[offset:offset + n_inputs]
else:
_ordered_args[mit_sot_rightOrder[idx]] = \
mit_sot_inner_inputs[offset:offset+n_inputs]
mit_sot_inner_inputs[offset:offset + n_inputs]
offset += n_inputs
for idx in xrange(n_sit_sot):
if n_fixed_steps in [1,-1]:
if n_fixed_steps in [1, -1]:
_ordered_args[sit_sot_rightOrder[idx]] = \
[ sit_sot_inner_slices[idx] ]
[sit_sot_inner_slices[idx]]
else:
_ordered_args[sit_sot_rightOrder[idx]] = \
[ sit_sot_inner_inputs[idx] ]
[sit_sot_inner_inputs[idx]]
ordered_args = []
for ls in _ordered_args:
ordered_args += ls
if n_fixed_steps in [1,-1]:
if n_fixed_steps in [1, -1]:
args = (inner_slices +
ordered_args +
non_seqs )
non_seqs)
else:
args = ( inner_seqs +
args = (inner_seqs +
ordered_args +
non_seqs )
non_seqs)
# add only the non-shared variables and non-constants to the arguments of the dummy
# function [ a function should not get shared variables or constants as input ]
# add only the non-shared variables and non-constants to the arguments of
# the dummy function [ a function should not get shared variables or
# constants as input ]
dummy_args = [arg for arg in args
if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant) )]
not isinstance(arg, tensor.Constant))]
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
if condition is not None:
as_while = True
......@@ -733,14 +733,13 @@ def scan( fn
### Step 3. Check if we actually need scan and remove it if we don't
##
if n_fixed_steps in [1, -1]:
# We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have
if condition is not None:
_logger.warning( ('When the number of steps is fixed and equal to 1,'
' the provided stopping condition, ', str(condition),
' is ignored'))
_logger.warning(('When the number of steps is fixed and equal '
'to 1, the provided stopping condition, ',
str(condition), ' is ignored'))
for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an
......@@ -749,16 +748,15 @@ def scan( fn
# then, if we return the output as given by the innner function
# this will represent only a slice and it will have one
# dimension less.
if ( isinstance(inner_out.type, tensor.TensorType) and
if (isinstance(inner_out.type, tensor.TensorType) and
return_steps.get(pos, 0) != 1):
outputs[pos] = tensor.unbroadcast(
tensor.shape_padleft(inner_out),0)
tensor.shape_padleft(inner_out), 0)
if len(outputs) == 1:
outputs = outputs[0]
return (outputs, updates)
##
### Step 4. Compile the dummy function
##
......@@ -778,11 +776,11 @@ def scan( fn
replace=dict(zip(non_seqs,
fake_nonseqs)))
all_inputs = itertools.ifilter(
lambda x: ( isinstance(x, gof.Variable) and
lambda x: (isinstance(x, gof.Variable) and
not isinstance(x, SharedVariable) and
not isinstance(x, gof.Constant) ),
gof.graph.inputs( fake_outputs) )
extra_inputs = filter( lambda x: x not in args + fake_nonseqs,
not isinstance(x, gof.Constant)),
gof.graph.inputs(fake_outputs))
extra_inputs = filter(lambda x: x not in args + fake_nonseqs,
all_inputs)
non_seqs += extra_inputs
## Note we do not use all_inputs directly since the order of variables
......@@ -792,12 +790,11 @@ def scan( fn
dummy_outs = outputs
if condition is not None:
dummy_outs.append(condition)
dummy_f = function( dummy_args
, dummy_outs
, updates = updates
, mode = compile.mode.Mode(linker='py',
optimizer=None) )
dummy_f = function(dummy_args,
dummy_outs,
updates=updates,
mode=compile.mode.Mode(linker='py',
optimizer=None))
##
### Step 5. Re-arange inputs of scan into a more strict order
......@@ -806,7 +803,6 @@ def scan( fn
## Step 5.0 Check the outputs of the dummy function to see if they
## match with user provided data
# if the number of outputs to the function does not match the number of
# assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the
......@@ -814,7 +810,7 @@ def scan( fn
tmp_dummy_f_outs = len(dummy_f.maker.outputs)
if as_while:
tmp_dummy_f_outs -= 1
if not ( tmp_dummy_f_outs == n_outs or outs_info == []):
if not (tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) ')
......@@ -823,95 +819,91 @@ def scan( fn
n_outs = len(dummy_f.maker.outputs)
if as_while:
n_outs = n_outs - 1
outs_info = [ dict() for x in xrange(n_outs) ]
outs_info = [dict() for x in xrange(n_outs)]
## Step 5.1 Outputs with taps different then -1
for i, out in enumerate(outs_info):
if 'taps' in out and out['taps'] != [-1]:
mit_sot_inner_outputs.append( outputs[i])
mit_sot_inner_outputs.append(outputs[i])
## Step 5.2 Outputs with tap equal to -1
for i, out in enumerate(outs_info):
if 'taps' in out and out['taps'] == [-1]:
sit_sot_inner_outputs.append( outputs[i] )
sit_sot_inner_outputs.append(outputs[i])
## Step 5.3 Outputs that correspond to update rules of shared variables
givens = {}
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
givens = {}
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
shared_inner_outputs = []
for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable)
if getattr(input.variable,'name', None) is not None:
if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append( new_var )
shared_scan_inputs.append( input.variable )
shared_inner_outputs.append( input.update )
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
n_nit_sot = 0
nit_sot_inner_outputs = []
nit_sot_return_steps = {}
nit_sot_rightOrder = []
for i,out in enumerate(outs_info):
nit_sot_return_steps = {}
nit_sot_rightOrder = []
for i, out in enumerate(outs_info):
if not 'taps' in out:
nit_sot_inner_outputs.append( outputs[i] )
nit_sot_inner_outputs.append(outputs[i])
if i in return_steps:
nit_sot_return_steps[n_nit_sot] = return_steps[i]
nit_sot_rightOrder.append( i )
nit_sot_rightOrder.append(i)
n_nit_sot += 1
## Step 5.5 all other arguments including extra inputs
other_scan_args = []
other_scan_args = []
other_inner_args = []
other_scan_args += [ arg for arg in non_seqs
other_scan_args += [arg for arg in non_seqs
if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))]
## Step 5.6 all shared variables with no update rules
other_inner_args += [ safe_new(arg,'_copy') for arg in non_seqs
other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs
if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))]
givens.update( dict( zip(other_scan_args, other_inner_args) ))
other_shared_scan_args = [ arg.variable for arg
givens.update(dict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if ( isinstance(arg.variable, SharedVariable) and
not arg.update) ]
other_shared_inner_args = [ safe_new(arg.variable, '_copy') for arg
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if ( isinstance(arg.variable, SharedVariable) and
not arg.update) ]
givens.update( dict( zip( other_shared_scan_args,
other_shared_inner_args) ) )
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
givens.update(dict(zip(other_shared_scan_args,
other_shared_inner_args)))
##
### Step 6. Re-order the outputs and clone them replacing things
### using the givens
##
inner_inputs = ( inner_seqs +
mit_mot_inner_inputs +
mit_sot_inner_inputs +
sit_sot_inner_inputs +
shared_inner_inputs +
inner_inputs = (inner_seqs +
mit_mot_inner_inputs +
mit_sot_inner_inputs +
sit_sot_inner_inputs +
shared_inner_inputs +
other_shared_inner_args +
other_inner_args )
other_inner_args)
inner_outs = ( mit_mot_inner_outputs +
mit_sot_inner_outputs +
sit_sot_inner_outputs +
nit_sot_inner_outputs +
shared_inner_outputs )
inner_outs = (mit_mot_inner_outputs +
mit_sot_inner_outputs +
sit_sot_inner_outputs +
nit_sot_inner_outputs +
shared_inner_outputs)
if condition is not None:
inner_outs.append(condition)
# Cuda is imported here, instead of being imported on top of the file
......@@ -926,59 +918,58 @@ def scan( fn
# variables are put on GPU right aways >:| ,
new_givens = {}
for w,w_copy in givens.iteritems():
for w, w_copy in givens.iteritems():
if (isinstance(w.type, cuda.CudaNdarrayType)
and isinstance(w_copy.type, tensor.TensorType)):
for o in inner_outs:
new_givens = traverse(o,w,w_copy, new_givens)
new_givens = traverse(o, w, w_copy, new_givens)
else:
new_givens[w] = w_copy
else:
new_givens = givens
new_outs = scan_utils.clone(inner_outs, replace = new_givens)
new_outs = scan_utils.clone(inner_outs, replace=new_givens)
##
### Step 7. Create the Scan Op
##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
info = {}
info = {}
info['tap_array'] = tap_array
info['n_seqs'] = n_seqs
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['tap_array'] = tap_array
info['n_seqs'] = n_seqs
info['n_mit_mot'] = n_mit_mot
info['n_mit_mot_outs'] = n_mit_mot_outs
info['mit_mot_out_slices'] = mit_mot_out_slices
info['n_mit_sot'] = n_mit_sot
info['n_sit_sot'] = n_sit_sot
info['n_shared_outs'] = n_shared_outs
info['n_nit_sot'] = n_nit_sot
info['truncate_gradient'] = truncate_gradient
info['name'] = name
info['mode'] = mode
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = profile
local_op = scan_op.Scan( inner_inputs, new_outs, info )
info['n_mit_sot'] = n_mit_sot
info['n_sit_sot'] = n_sit_sot
info['n_shared_outs'] = n_shared_outs
info['n_nit_sot'] = n_nit_sot
info['truncate_gradient'] = truncate_gradient
info['name'] = name
info['mode'] = mode
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = profile
local_op = scan_op.Scan(inner_inputs, new_outs, info)
##
### Step 8. Compute the outputs using the scan op
##
_scan_inputs = ( scan_seqs +
mit_mot_scan_inputs +
mit_sot_scan_inputs +
sit_sot_scan_inputs +
shared_scan_inputs +
[ actual_n_steps for x in xrange(n_nit_sot) ] +
other_shared_scan_args +
other_scan_args )
_scan_inputs = (scan_seqs +
mit_mot_scan_inputs +
mit_sot_scan_inputs +
sit_sot_scan_inputs +
shared_scan_inputs +
[actual_n_steps for x in xrange(n_nit_sot)] +
other_shared_scan_args +
other_scan_args)
scan_inputs = []
for arg in [actual_n_steps]+ _scan_inputs:
for arg in [actual_n_steps] + _scan_inputs:
try:
arg = tensor.as_tensor_variable(arg)
except TypeError:
......@@ -986,8 +977,8 @@ def scan( fn
# to make sure no input is a cuda ndarrays
pass
scan_inputs += [arg]
scan_outs = local_op(* scan_inputs )
if type(scan_outs) not in (list,tuple):
scan_outs = local_op(*scan_inputs)
if type(scan_outs) not in (list, tuple):
scan_outs = [scan_outs]
##
### Step 9. Figure out which outs are update rules for shared variables
......@@ -995,55 +986,57 @@ def scan( fn
##
update_map = Updates()
def remove_dimensions( outs, steps_return, offsets = None):
def remove_dimensions(outs, steps_return, offsets=None):
out_ls = []
for idx, out in enumerate(outs):
if idx in steps_return:
if steps_return[idx] > 1:
out_ls.append( out[-steps_return[idx]:] )
out_ls.append(out[-steps_return[idx]:])
else:
out_ls.append( out[-1] )
out_ls.append(out[-1])
else:
if offsets is None:
out_ls.append( out )
out_ls.append(out)
else:
out_ls.append( out[offsets[idx]:] )
out_ls.append(out[offsets[idx]:])
return out_ls
offset = n_mit_mot
offsets = [ abs(numpy.min(x)) for x in mit_sot_tap_array ]
offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
mit_sot_outs = remove_dimensions(
scan_outs[offset:offset+n_mit_sot]
, mit_sot_return_steps
, offsets )
scan_outs[offset:offset + n_mit_sot],
mit_sot_return_steps,
offsets)
offset += n_mit_sot
offsets = [ 1 for x in xrange(n_sit_sot) ]
offsets = [1 for x in xrange(n_sit_sot)]
sit_sot_outs = remove_dimensions(
scan_outs[offset:offset+n_sit_sot]
, sit_sot_return_steps
, offsets )
scan_outs[offset:offset + n_sit_sot],
sit_sot_return_steps,
offsets)
offset += n_sit_sot
nit_sot_outs = remove_dimensions(
scan_outs[offset:offset+n_nit_sot]
, nit_sot_return_steps )
scan_outs[offset:offset + n_nit_sot],
nit_sot_return_steps)
offset += n_nit_sot
for idx, update_rule in enumerate(scan_outs[offset:offset+n_shared_outs]):
for idx, update_rule in enumerate(
scan_outs[offset:offset + n_shared_outs]):
update_map[shared_scan_inputs[idx]] = update_rule
_scan_out_list = ( mit_sot_outs +
sit_sot_outs +
nit_sot_outs )
_scan_out_list = (mit_sot_outs +
sit_sot_outs +
nit_sot_outs)
# Step 10. I need to reorder the outputs to be in the order expected by
# the user
rightOrder = ( mit_sot_rightOrder +
sit_sot_rightOrder +
nit_sot_rightOrder )
scan_out_list = [None]*len(rightOrder)
for idx,pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx]
rightOrder = (mit_sot_rightOrder +
sit_sot_rightOrder +
nit_sot_rightOrder)
scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx]
if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0:
......
......@@ -2481,6 +2481,26 @@ class T_Scan(unittest.TestCase):
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 1
def test_grad_multiple_seqs_different_nsteps(self):
    """Regression test for a grad bug reported by Michael Forbes.

    When scan receives sequences of different lengths, they must be
    clipped to the actual number of steps *before* the gradient is
    computed, so that reversing them for the backward pass lines the
    values up correctly.
    """
    coefficients = theano.tensor.vector('c')
    free_variable = theano.tensor.scalar('x')
    _max_coefficients_supported = 1000
    # Deliberately much longer than `coefficients`, so scan must clip it.
    all_powers = theano.tensor.arange(_max_coefficients_supported)
    polynomial_terms, updates = theano.scan(
        lambda coeff, power, free_var: coeff * (free_var ** power),
        sequences=[coefficients, all_powers],
        non_sequences=free_variable,
        outputs_info=None)
    polynomial = polynomial_terms.sum()
    gradient = theano.tensor.grad(polynomial, free_variable)
    evaluate = theano.function([coefficients, free_variable], gradient)
    # P(x) = 1 + 2x - 3x^2 + 4x^3, so P'(2) = 2 - 12*2 + 12*4 = 38.
    assert evaluate([1.0, 2.0, -3.0, 4.0], 2.0) == 38
def test_return_steps(self):
rng = numpy.random.RandomState(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(size = (2,), low = -5.,high = 5.))
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment