提交 b34cccb8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

some fix for scan when computing gradients wrt to shared variables that where…

some fix for scan when computing gradients wrt to shared variables that where not passed to the inner function (added a new test in test_scan.py
上级 1542aa1e
...@@ -352,13 +352,14 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -352,13 +352,14 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
:param mode: :param mode:
The mode used when compiling the theano function in the Scan op. The mode used when compiling the theano function in the Scan op.
If None will use the config mode. If None, it will use the config mode. If None and the config mode is set to
If None and the config mode is a a profile mode, we will create a new instance profile mode, it we will create a new instance of the ProfileMode in order
to compute correctly the timming. to compute the timming correctly.
Otherwise the time spend in Scan will show up twice in the profiling, once If no new instance is created the time spend in Scan will show up twice in the
as the time taken by scan, and a second time as taken by the individial ops profiling, once as the time taken by scan, and the second time as the time
that scan calls to do a iteration step. taken by the ops inside scan. This will be even worse for multiple cascading
The new profiler instance will be printed when python exits. scans.
The new profiler instance will be printed when python exits.
:rtype: tuple :rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a :return: tuple of the form (outputs, updates); ``outputs`` is either a
...@@ -455,7 +456,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -455,7 +456,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
sequences_taps[i] = seqs[i]['taps'] sequences_taps[i] = seqs[i]['taps']
# wrap outputs info in a dictionary if they are not already # wrap outputs info in a dictionary if they are not already
# in the same pass create a init_outs_taps dictionary and a inplace map # in one and in the same pass create a init_outs_taps dictionary and a inplace map
for i in xrange(n_outs): for i in xrange(n_outs):
if outs_info[i]: if outs_info[i]:
# If output is a dictionary, collect the number of steps the # If output is a dictionary, collect the number of steps the
...@@ -480,25 +481,29 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -480,25 +481,29 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
outs_info[i] = dict(initial=outs_info[i], taps = [-1]) outs_info[i] = dict(initial=outs_info[i], taps = [-1])
# if there is no initial state but there are taps # if there is no initial state but there are taps
# then return an error because it makes no sense # then return an error because it makes no sense
elif (not outs_info[i].get('initial',None)) and(outs_info[i].get('taps',None)): elif (not outs_info[i].get('initial',None)) and \
(outs_info[i].get('taps',None)):
raise ValueError('If you are using slices of an output you need to '\ raise ValueError('If you are using slices of an output you need to '\
'provide a initial state for it', outs_info[i]) 'provide a initial state for it', outs_info[i])
# if there is an intial state but no tap, we will add the default value for # if there is an intial state but no tap, we will add the default value
# taps, namely [-1] ( previous value); not that this will happen even though # for taps, namely [-1] ( previous value); not that this will happen
# you have provided for taps the value None, which is a bit strange (why would # even though you have provided for taps the value None, which is a bit
# one provide an initial state but tell scan not to use it ? ), just that # strange (why would one provide an initial state but tell scan not to
# in that case we will throw in a warning message pointing out this inconsistency # use it ? ), just that in that case we will throw in a warning message
elif outs_info[i].get('initial',None) and ( not outs_info[i].get('taps',None)): # pointing out this inconsistency
elif outs_info[i].get('initial',None) and \
( not outs_info[i].get('taps',None)):
if outs_info[i].has_key('taps'): if outs_info[i].has_key('taps'):
warning('You are providing a initial state for an output, but yet tell scan' warning('You are providing a initial state for an output and then '
'not to use it. Why? Scan will overwrite this setting and use the previous' 'tell scan not to use it. Why? Scan will overwrite this setting'
'value of the provided initial state. If this is not what you wanted, check' ' and use the previous value of the provided initial state. If'
'your code and do not provide the initial state') ' this is not what you wanted, check your code and do not '
'provide the initial state')
outs_info[i]['taps'] = [-1] outs_info[i]['taps'] = [-1]
else: else:
# if the output is a None then replace it with an empty dictionary for easing # if the output is a None then replace it with an empty dictionary for
# up dealing with this case later one ( we can directly call .has_key and things # easing up dealing with this case later one ( we can directly call .has_key
# like this # and things like this
outs_info[i] = dict() outs_info[i] = dict()
store_steps += [0] store_steps += [0]
...@@ -507,16 +512,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -507,16 +512,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# is how the Scan Op expects this information, separeted from the variables # is how the Scan Op expects this information, separeted from the variables
outputs_taps[i] = outs_info[i]['taps'] outputs_taps[i] = outs_info[i]['taps']
if outs_info[i].get('inplace', None): if outs_info[i].get('inplace', None):
# The same is true for the inplace info; it has to go into a separate dictionary # The same is true for the inplace info; it has to go into a separate
# based on index; Note that the input we're replacing should also come as an # dictionary based on index; Note that the input we're replacing should also
# index, therefore we have to look for it here # come as an index, therefore we have to look for it at this point
found = None found = None
for k in xrange(n_seqs): for k in xrange(n_seqs):
if seqs[k].get('input', None) == outs_info[i].get('inplace',None): if seqs[k].get('input', None) == outs_info[i].get('inplace',None):
found = k found = k
if found != None: if found != None:
# NOTE : inplace_map is identical to destroy_map, i.e. it tells what output # NOTE : inplace_map is identical to destroy_map, i.e. it tells what
# is computed inplace of what input !! # output is computed inplace of what input !!
inplace_map[i] = found inplace_map[i] = found
else: else:
raise ValueError('Asked to compute in place of a non-input variable',\ raise ValueError('Asked to compute in place of a non-input variable',\
...@@ -528,8 +533,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -528,8 +533,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# function to detect shared variables and their updates # function to detect shared variables and their updates
# and to construct a new and complete list of inputs and outputs # and to construct a new and complete list of inputs and outputs
args = [] # list of arguments args = [] # list of arguments
dummy_notshared_ins = 0 # number of arguments corresponding to input sequences dummy_notshared_ins = 0 # number of arguments corresponding to input seqs
dummy_notshared_init_outs = 0 # number of arguments corresponding to output sequences dummy_notshared_init_outs = 0 # number of arguments corresponding to output seqs
slice_to_seqs = [] # for each slice index of the corresponding input slice_to_seqs = [] # for each slice index of the corresponding input
# go through sequences picking up time slices as needed # go through sequences picking up time slices as needed
for i,seq in enumerate(seqs): for i,seq in enumerate(seqs):
...@@ -626,15 +631,15 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -626,15 +631,15 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# anything .. # anything ..
# remove shared variables from the non sequences list # remove shared variables from the non sequences list
# such that we can compile the function ( the user has the option to add them when writing # such that we can compile the function ( the user has the option to add them when
# scan, because in some situations this might make the code more readable) # writing scan, because in some situations this might make the code more readable)
# Also duplicate the list of non sequences arguments to contain copies of the non-shared # Also duplicate the list of non sequences arguments to contain copies of the
# inputs ( this fixes the case when one of this inputs has a default update attached to it # non-shared inputs ( this fixes the case when one of this inputs has a default
# that belongs to some shared random stream ). # update attached to it that belongs to some shared random stream )
# #
# Note : In that case, scan assumes that you do not want to draw new numbers at every call ( you # Note : In that case, scan assumes that you do not want to draw new numbers at
# would have made the internal function do that explicitly if you wanted to) but rather to # every call ( you would have made the internal function do that explicitly
# use that initial draw as a matrix of values # if you wanted to) but rather to use that initial draw as a matrix of values
new_non_seqs = [] new_non_seqs = []
notshared_other_args = [] notshared_other_args = []
notshared_other_args_copies = [] notshared_other_args_copies = []
...@@ -642,6 +647,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -642,6 +647,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
if not isinstance(non_seq, SharedVariable): if not isinstance(non_seq, SharedVariable):
if n_fixed_steps not in [-1,1]: if n_fixed_steps not in [-1,1]:
non_seq_copy = non_seq.type() non_seq_copy = non_seq.type()
if non_seq.name :
non_seq_copy.name = non_seq.name + '_copy'
else: else:
non_seq_copy = non_seq non_seq_copy = non_seq
notshared_other_args += [non_seq] notshared_other_args += [non_seq]
...@@ -788,13 +795,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -788,13 +795,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# Skip the slices that we've added to the inner_fn which will be the first elements # Skip the slices that we've added to the inner_fn which will be the first elements
# of f.maker.epanded_inputs and which we know that are not shared # of f.maker.epanded_inputs and which we know that are not shared
fromIdx = dummy_notshared_ins + dummy_notshared_init_outs fromIdx = dummy_notshared_ins + dummy_notshared_init_outs
copy_map = {}
for input in dummy_f.maker.expanded_inputs[fromIdx:] : for input in dummy_f.maker.expanded_inputs[fromIdx:] :
# If input is a shared variable that gets updated, then # If input is a shared variable that gets updated, then
# this shared variable will be an output of our inner function # this shared variable will be an output of our inner function
if isinstance(input.variable, SharedVariable) and input.update: if isinstance(input.variable, SharedVariable) and input.update:
# Create a copy of it # Create a copy of it
new_var = input.variable.type() new_var = input.variable.type()
if input.variable.name:
new_var.name = input.variable.name + '_copy'
copy_map[new_var] = input.variable
inner_fn_inputs.append(new_var) inner_fn_inputs.append(new_var)
# add it to the slices at the end # add it to the slices at the end
slice_to_seqs += [ n_extended_outs ] slice_to_seqs += [ n_extended_outs ]
...@@ -818,9 +828,13 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -818,9 +828,13 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# make sure that we do not add the same shared variable twice # make sure that we do not add the same shared variable twice
if isinstance(input.variable, SharedVariable) and not input.update: if isinstance(input.variable, SharedVariable) and not input.update:
shared_non_seqs += [input.variable] shared_non_seqs += [input.variable]
inner_fn_inputs += [input.variable.type() ] new_var = input.variable.type()
if input.variable.name:
new_var.name = input.variable.name + '_copy'
inner_fn_inputs += [new_var]
slice_to_seqs += [ n_extended_outs] slice_to_seqs += [ n_extended_outs]
givens[input.variable] = inner_fn_inputs[-1] givens[input.variable] = inner_fn_inputs[-1]
copy_map[inner_fn_inputs[-1]] = input.variable
elif not isinstance(input.variable, SharedVariable): elif not isinstance(input.variable, SharedVariable):
# also add the normal tensor that are non sequences at the # also add the normal tensor that are non sequences at the
# end of the inputs intertwingled with the shared variables # end of the inputs intertwingled with the shared variables
...@@ -849,7 +863,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -849,7 +863,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# a gradient # a gradient
n_outs, inner_fn_notshared_ins_idx, inner_fn_shared_ins_idx, n_outs, inner_fn_notshared_ins_idx, inner_fn_shared_ins_idx,
go_backwards, store_steps, return_steps, mode, name = name ) go_backwards, store_steps, return_steps, mode, name = name )
# Shortcut for attaching this property to the Scan op
local_op.copy_map = copy_map
# Call the object on the input sequences, initial values for outs, # Call the object on the input sequences, initial values for outs,
# and non sequences # and non sequences
for seq in seqs : for seq in seqs :
...@@ -1200,7 +1215,8 @@ class Scan(Op): ...@@ -1200,7 +1215,8 @@ class Scan(Op):
else: else:
# check if you are using past value .. through in a warning and do not # check if you are using past value .. through in a warning and do not
# work inplace # work inplace
if inplace_map.has_key(i) and seqs_taps.has_key(inplace_map[i]) and seqs_taps[inplace_map[i]] < 0: if inplace_map.has_key(i) and seqs_taps.has_key(inplace_map[i]) and\
seqs_taps[inplace_map[i]] < 0:
warning('Can not work inplace because of past values') warning('Can not work inplace because of past values')
if self.store_steps[i] == 1 : if self.store_steps[i] == 1 :
y+= [ None ] y+= [ None ]
...@@ -1262,10 +1278,10 @@ class Scan(Op): ...@@ -1262,10 +1278,10 @@ class Scan(Op):
else: else:
k = i + sz + tap_value k = i + sz + tap_value
if k < 0: if k < 0:
# past value not provided.. issue a warning and use 0s of the # past value not provided.. issue a warning and use
# correct dtype # 0s of the correct dtype
fn_args += [numpy.zeros(args[j+n_seqs][0].shape, dtype = fn_args += [numpy.zeros(args[j+n_seqs][0].shape, \
args[j+n_sqs][0].dtype)] dtype = args[j+n_sqs][0].dtype)]
warning(('Past value %d for output %d not given in ' warning(('Past value %d for output %d not given in '
'inital out') % (j,tap_value)) 'inital out') % (j,tap_value))
else: else:
...@@ -1280,7 +1296,8 @@ class Scan(Op): ...@@ -1280,7 +1296,8 @@ class Scan(Op):
else: else:
# storing only the last k # storing only the last k
# get what idx we want # get what idx we want
req_idx = (self.idx_store_steps[j] + tap_value + self.store_steps[j]) req_idx = (self.idx_store_steps[j] + tap_value + \
self.store_steps[j])
# we need this modula self.store_steps[j] # we need this modula self.store_steps[j]
req_idx = req_idx % self.store_steps[j] req_idx = req_idx % self.store_steps[j]
fn_args += [y[j][req_idx] ] fn_args += [y[j][req_idx] ]
...@@ -1292,42 +1309,48 @@ class Scan(Op): ...@@ -1292,42 +1309,48 @@ class Scan(Op):
#update outputs #update outputs
for j in xrange(n_outs): for j in xrange(n_outs):
if self.store_steps[j] <1: if self.store_steps[j] <1:
# if you have provided no size for the missing output you might find yourself # if you have provided no size for the missing output you might
# here with a incorect array .. if that happens realocate memory for the # find yourself here with a incorect array .. if that happens
# needed array # realocate memory for the needed array
try : try :
if hasattr(something[j],'dtype') and (y[j].dtype != something[j].dtype) : if hasattr(something[j],'dtype') and (y[j].dtype != \
something[j].dtype) :
raise ValueError('wrong dtype') raise ValueError('wrong dtype')
y[j][i] = something[j] y[j][i] = something[j]
except : except :
y[j]= numpy.empty((n_steps,)+something[j].shape, dtype= something[j].dtype) y[j]= numpy.empty((n_steps,)+something[j].shape, dtype= \
something[j].dtype)
y[j][i] = something[j] y[j][i] = something[j]
elif self.store_steps[j] == 1: elif self.store_steps[j] == 1:
try: try:
if hasattr(something[j],'dtype') and y[j].dtype != something[j].dtype: if hasattr(something[j],'dtype') and y[j].dtype != \
something[j].dtype:
raise ValueError('wrong dtype') raise ValueError('wrong dtype')
y[j] = something[j] y[j] = something[j]
except: except:
y[j] = numpy.empty( something[j].shape, dtype = something[j].dtype) y[j] = numpy.empty( something[j].shape, dtype = \
something[j].dtype)
y[j] = something[j] y[j] = something[j]
else: else:
try: try:
if hasattr(something[j],'dtype') and y[j].dtype != something[j].dtype: if hasattr(something[j],'dtype') and y[j].dtype != \
something[j].dtype:
raise ValueError('worng dtype') raise ValueError('worng dtype')
y[j][self.idx_store_steps[j]] = something[j] y[j][self.idx_store_steps[j]] = something[j]
self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) % self.store_steps[j] self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) %\
self.store_steps[j]
except: except:
y[j] = numpy.empty( (self.store_steps[j],)+something[j].shape, \ y[j] = numpy.empty( (self.store_steps[j],)+something[j].shape, \
dtype = something[j].dtype) dtype = something[j].dtype)
y[j][idx_store_steps[j]] = something[j] y[j][idx_store_steps[j]] = something[j]
self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) % self.store_steps[j] self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) %\
self.store_steps[j]
return y return y
def grad(self, args, g_outs): def grad(self, args, g_outs):
# forward pass - get the outputs after applying scan # forward pass - get the outputs after applying scan
scan_outputs = self(*args) scan_outputs = self(*args)
# make sure they are given as a list # make sure they are given as a list
...@@ -1338,10 +1361,14 @@ class Scan(Op): ...@@ -1338,10 +1361,14 @@ class Scan(Op):
clean_inputs = self.inputs[:self.inner_fn_start_shared] + \ clean_inputs = self.inputs[:self.inner_fn_start_shared] + \
self.inputs[self.inner_fn_start_shared + \ self.inputs[self.inner_fn_start_shared + \
self.inner_fn_end_shared:] self.inner_fn_end_shared:]
clean_inputs = [ self.copy_map.get(x,x) for x in clean_inputs]
s_inputs = [self.copy_map.get(x,x) for x in self.inputs ]
# function that computes the gradient (we sum over the gradients # function that computes the gradient (we sum over the gradients
# with respect to all outputs # with respect to all outputs
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
gmap = gradient.grad_sources_inputs( \ gmp = gradient.grad_sources_inputs( \
[(y,g_y)], clean_inputs, False) [(y,g_y)], clean_inputs, False)
def zero(p): def zero(p):
try: try:
...@@ -1351,8 +1378,7 @@ class Scan(Op): ...@@ -1351,8 +1378,7 @@ class Scan(Op):
return tensor.TensorConstant(tensor.TensorType(\ return tensor.TensorConstant(tensor.TensorType(\
dtype=use_dtype, broadcastable=[]), dtype=use_dtype, broadcastable=[]),
safe_asarray._asarray(0,dtype = use_dtype)) safe_asarray._asarray(0,dtype = use_dtype))
return [gmp.get(p, zero(p)) for p in s_inputs]
return [gmap.get(p, zero(p)) for p in self.inputs]
# this are g_outs for the inner function (that computes the gradients) # this are g_outs for the inner function (that computes the gradients)
...@@ -1385,7 +1411,6 @@ class Scan(Op): ...@@ -1385,7 +1411,6 @@ class Scan(Op):
inner_gfn_outs[i] = x inner_gfn_outs[i] = x
# backwards pass # backwards pass
for i in xrange(len(inner_gfn_outs)): for i in xrange(len(inner_gfn_outs)):
if inner_gfn_outs[i] == None: if inner_gfn_outs[i] == None:
...@@ -1399,7 +1424,8 @@ class Scan(Op): ...@@ -1399,7 +1424,8 @@ class Scan(Op):
# after n_outs_not_shared ... # after n_outs_not_shared ...
g_outs[i] = tensor.zeros_like(scan_outputs[i]) g_outs[i] = tensor.zeros_like(scan_outputs[i])
except: except:
g_outs[i] = theano.tensor.constant(numpy.array(0,dtype=theano.config.floatX)) g_outs[i] = theano.tensor.constant(numpy.array(0,dtype=\
theano.config.floatX))
inner_gfn_ins = inner_g_outs + self.inputs inner_gfn_ins = inner_g_outs + self.inputs
g_args = [self.n_steps] + g_outs[:self.n_outs_not_shared] \ g_args = [self.n_steps] + g_outs[:self.n_outs_not_shared] \
...@@ -1683,7 +1709,7 @@ class ScanSpaceOptimizer(Optimizer): ...@@ -1683,7 +1709,7 @@ class ScanSpaceOptimizer(Optimizer):
if isinstance(op, Scan): if isinstance(op, Scan):
outputs = node.outputs outputs = node.outputs
store_steps = [0 for x in outputs] store_steps = [0 for x in outputs]
# check the otuputs # check the outputs
for i,out in enumerate(node.outputs): for i,out in enumerate(node.outputs):
if op.store_steps[i] == 0 : if op.store_steps[i] == 0 :
# if we do not have a range for this output # if we do not have a range for this output
...@@ -1693,43 +1719,44 @@ class ScanSpaceOptimizer(Optimizer): ...@@ -1693,43 +1719,44 @@ class ScanSpaceOptimizer(Optimizer):
if type(cl) == str: if type(cl) == str:
# if the node is actually an output, then # if the node is actually an output, then
# we need to store the entire thing # we need to store the entire thing
req_steps = 0 req_steps = None
break break
else: else:
if not isinstance(cl.op, if not isinstance(cl.op,
tensor.basic.Subtensor): tensor.basic.Subtensor):
# if any of the clients is not a subtensor # if any of the clients is not a subtensor
# we also need to store the enitre thing # we also need to store the enitre thing
req_steps = 0 req_steps = None
break break
else: else:
# if it is a tensor, and the first # if it is a tensor, and the first
# dimension is just -1 # dimension is just -1
if cl.op.idx_list[0] == -1 : if cl.op.idx_list[0] == -1 and req_steps != None:
req_steps = numpy.max([1, req_steps]) req_steps = numpy.max([1, req_steps])
else: else:
# or a constant that evaluates to # or a constant that evaluates to
# -1 # -1
try: try:
idx = opt.get_constant_value(cl.op.idx_list[0]) idx = opt.get_constant_value(\
cl.op.idx_list[0])
if idx== -1: if idx== -1:
req_steps = numpy.max([1, req_steps]) req_steps = numpy.max([1, req_steps])
else: else:
req_steps = 0 req_steps = None
break break
except: except:
req_steps = 0 req_steps = None
break break
store_steps[i] = req_steps store_steps[i] = req_steps if req_steps != None else 0
else: else:
store_steps[i] = op.store_steps[i] store_steps[i] = op.store_steps[i]
if numpy.any(store_steps!= op.store_steps): if numpy.any(store_steps!= op.store_steps):
new_scan = Scan((op.inputs, op.outputs, op.givens, new_scan = Scan((op.inputs, op.outputs, op.givens,
op.slice_to_seqs),op.n_seqs, op.n_outs, op.slice_to_seqs),op.n_seqs, op.n_outs,
op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps, op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps,
op.truncate_gradient, op.n_outs_not_shared, op.inner_fn_start_shared, op.truncate_gradient, op.n_outs_not_shared,
op.inner_fn_end_shared, op.go_backwards, op.inner_fn_start_shared, op.inner_fn_end_shared,
store_steps, op.return_steps, op.mode, op.go_backwards, store_steps, op.return_steps, op.mode,
op.inplace, name = op.fn.name).make_node(*node.inputs) op.inplace, name = op.fn.name).make_node(*node.inputs)
# we not need to replace the outputs of scan # we not need to replace the outputs of scan
for i,out in enumerate(node.outputs): for i,out in enumerate(node.outputs):
...@@ -1757,8 +1784,8 @@ def scan_make_inplace(node): ...@@ -1757,8 +1784,8 @@ def scan_make_inplace(node):
return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs ) , op.n_seqs, return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs ) , op.n_seqs,
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps, op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps,
op.truncate_gradient, op.n_outs_not_shared, op.inner_fn_start_shared, op.truncate_gradient, op.n_outs_not_shared, op.inner_fn_start_shared,
op.inner_fn_end_shared, op.go_backwards, op.store_steps, op.return_steps, op.mode, op.inner_fn_end_shared, op.go_backwards, op.store_steps, op.return_steps,
inplace=True, name = op.fn.name).make_node(*node.inputs).outputs op.mode, inplace=True, name = op.fn.name).make_node(*node.inputs).outputs
return False return False
......
...@@ -4015,6 +4015,10 @@ def verify_grad(op, pt, n_tests=2, rng=None, eps=None, tol=None, mode=None, cast ...@@ -4015,6 +4015,10 @@ def verify_grad(op, pt, n_tests=2, rng=None, eps=None, tol=None, mode=None, cast
debug mode, which can be very slow if it has to verify a lot debug mode, which can be very slow if it has to verify a lot
of intermediate computations. of intermediate computations.
:note: This op does not support multiple outputs. In tests/test_scan.py there is
an experimental verify_grad that covers that case as well by using random
projections ..
""" """
pt = [numpy.array(p) for p in pt] pt = [numpy.array(p) for p in pt]
......
...@@ -6,7 +6,29 @@ import numpy ...@@ -6,7 +6,29 @@ import numpy
import random import random
import numpy.random import numpy.random
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
'''
Questions and notes about scan that should be answered :
* Even though it does not make it publically known in
the documentation, scan allows you to set both a return_steps
flag and a store_steps flag ( the first one is a soft condition telling
you how many steps to return, the second one determines how much memory to
allocate). There is an optimization as well, that transforms return_steps to
store_steps. Questions :
- what happens if both flags are set ?
answer: whatever return_steps says is ignored, and store_steps is used
- the optimization works only with return_steps = -1; can it be made to work
with other values ?
answer: 6 Jul 2010 RP :it is a bit harry to figure out from the subtensors what
exactly you need
* Scan seems to do copies of every input variable. Is that needed?
answer : probably not, but it doesn't hurt also ( what we copy is theano variables,
which just cary information about the type / dimension of the data)
* There is some of scan functionality that is not well documented
'''
class multiple_outputs_numeric_grad: class multiple_outputs_numeric_grad:
...@@ -103,7 +125,7 @@ class multiple_outputs_numeric_grad: ...@@ -103,7 +125,7 @@ class multiple_outputs_numeric_grad:
# use it with the normal verify_grad rather than the # use it with the normal verify_grad rather than the
# copy-and-pasted one above. # copy-and-pasted one above.
# Also - add a reference to this technique in the # Also - add a reference to this technique in the
# verify_grad method so that other ops with multiple outputs can be tested. # verify_grad method so that other ops with multiple outputs can be tested. DONE - rp
def scan_project_sum(*args, **kwargs): def scan_project_sum(*args, **kwargs):
rng = theano.tensor.shared_randomstreams.RandomStreams(123) rng = theano.tensor.shared_randomstreams.RandomStreams(123)
scan_outputs, updates = theano.scan(*args, **kwargs) scan_outputs, updates = theano.scan(*args, **kwargs)
...@@ -941,6 +963,18 @@ class T_Scan(unittest.TestCase): ...@@ -941,6 +963,18 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose([ny2,ny2], nz2) assert numpy.allclose([ny2,ny2], nz2)
assert not numpy.allclose(ny1,ny2) assert not numpy.allclose(ny1,ny2)
def test_grad_of_shared(self):
x1 = theano.shared(3.)
x1.name = 'x1'
x2 = theano.tensor.vector('x2')
y, updates = theano.scan(lambda v: v*x1, sequences = x2)
m = theano.tensor.grad(y.sum(), x1)
f = theano.function([x2], m)
print f([2,3])
assert numpy.allclose(f([2,3]) , 5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论