提交 b34cccb8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Some fix for scan when computing gradients w.r.t. shared variables that were…

Some fix for scan when computing gradients w.r.t. shared variables that were not passed to the inner function (added a new test in test_scan.py).
上级 1542aa1e
......@@ -352,13 +352,14 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
:param mode:
The mode used when compiling the theano function in the Scan op.
If None will use the config mode.
If None and the config mode is a a profile mode, we will create a new instance
to compute correctly the timming.
Otherwise the time spend in Scan will show up twice in the profiling, once
as the time taken by scan, and a second time as taken by the individial ops
that scan calls to do a iteration step.
The new profiler instance will be printed when python exits.
If None, it will use the config mode. If None and the config mode is set to
profile mode, we will create a new instance of the ProfileMode in order
to compute the timing correctly.
If no new instance is created, the time spent in Scan will show up twice in the
profiling: once as the time taken by scan, and a second time as the time
taken by the ops inside scan. This will be even worse for multiple cascading
scans.
The new profiler instance will be printed when Python exits.
:rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a
......@@ -455,7 +456,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
sequences_taps[i] = seqs[i]['taps']
# wrap outputs info in a dictionary if they are not already
# in the same pass create a init_outs_taps dictionary and a inplace map
# in one and in the same pass create a init_outs_taps dictionary and a inplace map
for i in xrange(n_outs):
if outs_info[i]:
# If output is a dictionary, collect the number of steps the
......@@ -480,25 +481,29 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
outs_info[i] = dict(initial=outs_info[i], taps = [-1])
# if there is no initial state but there are taps
# then return an error because it makes no sense
elif (not outs_info[i].get('initial',None)) and(outs_info[i].get('taps',None)):
elif (not outs_info[i].get('initial',None)) and \
(outs_info[i].get('taps',None)):
raise ValueError('If you are using slices of an output you need to '\
'provide a initial state for it', outs_info[i])
# if there is an intial state but no tap, we will add the default value for
# taps, namely [-1] ( previous value); not that this will happen even though
# you have provided for taps the value None, which is a bit strange (why would
# one provide an initial state but tell scan not to use it ? ), just that
# in that case we will throw in a warning message pointing out this inconsistency
elif outs_info[i].get('initial',None) and ( not outs_info[i].get('taps',None)):
# if there is an initial state but no tap, we will add the default value
# for taps, namely [-1] (the previous value); note that this will happen
# even though you have provided for taps the value None, which is a bit
# strange (why would one provide an initial state but tell scan not to
# use it?), just that in that case we will throw in a warning message
# pointing out this inconsistency
elif outs_info[i].get('initial',None) and \
( not outs_info[i].get('taps',None)):
if outs_info[i].has_key('taps'):
warning('You are providing a initial state for an output, but yet tell scan'
'not to use it. Why? Scan will overwrite this setting and use the previous'
'value of the provided initial state. If this is not what you wanted, check'
'your code and do not provide the initial state')
warning('You are providing a initial state for an output and then '
'tell scan not to use it. Why? Scan will overwrite this setting'
' and use the previous value of the provided initial state. If'
' this is not what you wanted, check your code and do not '
'provide the initial state')
outs_info[i]['taps'] = [-1]
else:
# if the output is a None then replace it with an empty dictionary for easing
# up dealing with this case later one ( we can directly call .has_key and things
# like this
# if the output is a None then replace it with an empty dictionary for
# easing up dealing with this case later one ( we can directly call .has_key
# and things like this
outs_info[i] = dict()
store_steps += [0]
......@@ -507,16 +512,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# is how the Scan Op expects this information, separeted from the variables
outputs_taps[i] = outs_info[i]['taps']
if outs_info[i].get('inplace', None):
# The same is true for the inplace info; it has to go into a separate dictionary
# based on index; Note that the input we're replacing should also come as an
# index, therefore we have to look for it here
# The same is true for the inplace info; it has to go into a separate
# dictionary based on index; Note that the input we're replacing should also
# come as an index, therefore we have to look for it at this point
found = None
for k in xrange(n_seqs):
if seqs[k].get('input', None) == outs_info[i].get('inplace',None):
found = k
if found != None:
# NOTE : inplace_map is identical to destroy_map, i.e. it tells what output
# is computed inplace of what input !!
# NOTE : inplace_map is identical to destroy_map, i.e. it tells what
# output is computed inplace of what input !!
inplace_map[i] = found
else:
raise ValueError('Asked to compute in place of a non-input variable',\
......@@ -528,8 +533,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# function to detect shared variables and their updates
# and to construct a new and complete list of inputs and outputs
args = [] # list of arguments
dummy_notshared_ins = 0 # number of arguments corresponding to input sequences
dummy_notshared_init_outs = 0 # number of arguments corresponding to output sequences
dummy_notshared_ins = 0 # number of arguments corresponding to input seqs
dummy_notshared_init_outs = 0 # number of arguments corresponding to output seqs
slice_to_seqs = [] # for each slice index of the corresponding input
# go through sequences picking up time slices as needed
for i,seq in enumerate(seqs):
......@@ -626,15 +631,15 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# anything ..
# remove shared variables from the non sequences list
# such that we can compile the function ( the user has the option to add them when writing
# scan, because in some situations this might make the code more readable)
# Also duplicate the list of non sequences arguments to contain copies of the non-shared
# inputs ( this fixes the case when one of this inputs has a default update attached to it
# that belongs to some shared random stream ).
# such that we can compile the function ( the user has the option to add them when
# writing scan, because in some situations this might make the code more readable)
# Also duplicate the list of non sequences arguments to contain copies of the
# non-shared inputs ( this fixes the case when one of this inputs has a default
# update attached to it that belongs to some shared random stream )
#
# Note : In that case, scan assumes that you do not want to draw new numbers at every call ( you
# would have made the internal function do that explicitly if you wanted to) but rather to
# use that initial draw as a matrix of values
# Note : In that case, scan assumes that you do not want to draw new numbers at
# every call ( you would have made the internal function do that explicitly
# if you wanted to) but rather to use that initial draw as a matrix of values
new_non_seqs = []
notshared_other_args = []
notshared_other_args_copies = []
......@@ -642,6 +647,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
if not isinstance(non_seq, SharedVariable):
if n_fixed_steps not in [-1,1]:
non_seq_copy = non_seq.type()
if non_seq.name :
non_seq_copy.name = non_seq.name + '_copy'
else:
non_seq_copy = non_seq
notshared_other_args += [non_seq]
......@@ -788,13 +795,16 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# Skip the slices that we've added to the inner_fn which will be the first elements
# of f.maker.epanded_inputs and which we know that are not shared
fromIdx = dummy_notshared_ins + dummy_notshared_init_outs
copy_map = {}
for input in dummy_f.maker.expanded_inputs[fromIdx:] :
# If input is a shared variable that gets updated, then
# this shared variable will be an output of our inner function
if isinstance(input.variable, SharedVariable) and input.update:
# Create a copy of it
new_var = input.variable.type()
if input.variable.name:
new_var.name = input.variable.name + '_copy'
copy_map[new_var] = input.variable
inner_fn_inputs.append(new_var)
# add it to the slices at the end
slice_to_seqs += [ n_extended_outs ]
......@@ -818,9 +828,13 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# make sure that we do not add the same shared variable twice
if isinstance(input.variable, SharedVariable) and not input.update:
shared_non_seqs += [input.variable]
inner_fn_inputs += [input.variable.type() ]
new_var = input.variable.type()
if input.variable.name:
new_var.name = input.variable.name + '_copy'
inner_fn_inputs += [new_var]
slice_to_seqs += [ n_extended_outs]
givens[input.variable] = inner_fn_inputs[-1]
copy_map[inner_fn_inputs[-1]] = input.variable
elif not isinstance(input.variable, SharedVariable):
# also add the normal tensor that are non sequences at the
# end of the inputs intertwingled with the shared variables
......@@ -849,7 +863,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# a gradient
n_outs, inner_fn_notshared_ins_idx, inner_fn_shared_ins_idx,
go_backwards, store_steps, return_steps, mode, name = name )
# Shortcut for attaching this property to the Scan op
local_op.copy_map = copy_map
# Call the object on the input sequences, initial values for outs,
# and non sequences
for seq in seqs :
......@@ -1200,7 +1215,8 @@ class Scan(Op):
else:
# check if you are using past value .. through in a warning and do not
# work inplace
if inplace_map.has_key(i) and seqs_taps.has_key(inplace_map[i]) and seqs_taps[inplace_map[i]] < 0:
if inplace_map.has_key(i) and seqs_taps.has_key(inplace_map[i]) and\
seqs_taps[inplace_map[i]] < 0:
warning('Can not work inplace because of past values')
if self.store_steps[i] == 1 :
y+= [ None ]
......@@ -1262,10 +1278,10 @@ class Scan(Op):
else:
k = i + sz + tap_value
if k < 0:
# past value not provided.. issue a warning and use 0s of the
# correct dtype
fn_args += [numpy.zeros(args[j+n_seqs][0].shape, dtype =
args[j+n_sqs][0].dtype)]
# past value not provided.. issue a warning and use
# 0s of the correct dtype
fn_args += [numpy.zeros(args[j+n_seqs][0].shape, \
dtype = args[j+n_sqs][0].dtype)]
warning(('Past value %d for output %d not given in '
'inital out') % (j,tap_value))
else:
......@@ -1280,7 +1296,8 @@ class Scan(Op):
else:
# storing only the last k
# get what idx we want
req_idx = (self.idx_store_steps[j] + tap_value + self.store_steps[j])
req_idx = (self.idx_store_steps[j] + tap_value + \
self.store_steps[j])
# we need this modula self.store_steps[j]
req_idx = req_idx % self.store_steps[j]
fn_args += [y[j][req_idx] ]
......@@ -1292,42 +1309,48 @@ class Scan(Op):
#update outputs
for j in xrange(n_outs):
if self.store_steps[j] <1:
# if you have provided no size for the missing output you might find yourself
# here with a incorect array .. if that happens realocate memory for the
# needed array
# if you have provided no size for the missing output you might
# find yourself here with an incorrect array; if that happens,
# reallocate memory for the needed array
try :
if hasattr(something[j],'dtype') and (y[j].dtype != something[j].dtype) :
if hasattr(something[j],'dtype') and (y[j].dtype != \
something[j].dtype) :
raise ValueError('wrong dtype')
y[j][i] = something[j]
except :
y[j]= numpy.empty((n_steps,)+something[j].shape, dtype= something[j].dtype)
y[j]= numpy.empty((n_steps,)+something[j].shape, dtype= \
something[j].dtype)
y[j][i] = something[j]
elif self.store_steps[j] == 1:
try:
if hasattr(something[j],'dtype') and y[j].dtype != something[j].dtype:
if hasattr(something[j],'dtype') and y[j].dtype != \
something[j].dtype:
raise ValueError('wrong dtype')
y[j] = something[j]
except:
y[j] = numpy.empty( something[j].shape, dtype = something[j].dtype)
y[j] = numpy.empty( something[j].shape, dtype = \
something[j].dtype)
y[j] = something[j]
else:
try:
if hasattr(something[j],'dtype') and y[j].dtype != something[j].dtype:
if hasattr(something[j],'dtype') and y[j].dtype != \
something[j].dtype:
raise ValueError('worng dtype')
y[j][self.idx_store_steps[j]] = something[j]
self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) % self.store_steps[j]
self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) %\
self.store_steps[j]
except:
y[j] = numpy.empty( (self.store_steps[j],)+something[j].shape, \
dtype = something[j].dtype)
y[j][idx_store_steps[j]] = something[j]
self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) % self.store_steps[j]
self.idx_store_steps[j] = (self.idx_store_steps[j] + 1) %\
self.store_steps[j]
return y
def grad(self, args, g_outs):
# forward pass - get the outputs after applying scan
scan_outputs = self(*args)
# make sure they are given as a list
......@@ -1338,10 +1361,14 @@ class Scan(Op):
clean_inputs = self.inputs[:self.inner_fn_start_shared] + \
self.inputs[self.inner_fn_start_shared + \
self.inner_fn_end_shared:]
clean_inputs = [ self.copy_map.get(x,x) for x in clean_inputs]
s_inputs = [self.copy_map.get(x,x) for x in self.inputs ]
# function that computes the gradient (we sum over the gradients
# with respect to all outputs
def compute_gradient(y, g_y):
gmap = gradient.grad_sources_inputs( \
gmp = gradient.grad_sources_inputs( \
[(y,g_y)], clean_inputs, False)
def zero(p):
try:
......@@ -1351,8 +1378,7 @@ class Scan(Op):
return tensor.TensorConstant(tensor.TensorType(\
dtype=use_dtype, broadcastable=[]),
safe_asarray._asarray(0,dtype = use_dtype))
return [gmap.get(p, zero(p)) for p in self.inputs]
return [gmp.get(p, zero(p)) for p in s_inputs]
# this are g_outs for the inner function (that computes the gradients)
......@@ -1385,7 +1411,6 @@ class Scan(Op):
inner_gfn_outs[i] = x
# backwards pass
for i in xrange(len(inner_gfn_outs)):
if inner_gfn_outs[i] == None:
......@@ -1399,7 +1424,8 @@ class Scan(Op):
# after n_outs_not_shared ...
g_outs[i] = tensor.zeros_like(scan_outputs[i])
except:
g_outs[i] = theano.tensor.constant(numpy.array(0,dtype=theano.config.floatX))
g_outs[i] = theano.tensor.constant(numpy.array(0,dtype=\
theano.config.floatX))
inner_gfn_ins = inner_g_outs + self.inputs
g_args = [self.n_steps] + g_outs[:self.n_outs_not_shared] \
......@@ -1683,7 +1709,7 @@ class ScanSpaceOptimizer(Optimizer):
if isinstance(op, Scan):
outputs = node.outputs
store_steps = [0 for x in outputs]
# check the otuputs
# check the outputs
for i,out in enumerate(node.outputs):
if op.store_steps[i] == 0 :
# if we do not have a range for this output
......@@ -1693,43 +1719,44 @@ class ScanSpaceOptimizer(Optimizer):
if type(cl) == str:
# if the node is actually an output, then
# we need to store the entire thing
req_steps = 0
req_steps = None
break
else:
if not isinstance(cl.op,
tensor.basic.Subtensor):
# if any of the clients is not a subtensor
# we also need to store the enitre thing
req_steps = 0
req_steps = None
break
else:
# if it is a tensor, and the first
# dimension is just -1
if cl.op.idx_list[0] == -1 :
if cl.op.idx_list[0] == -1 and req_steps != None:
req_steps = numpy.max([1, req_steps])
else:
# or a constant that evaluates to
# -1
try:
idx = opt.get_constant_value(cl.op.idx_list[0])
idx = opt.get_constant_value(\
cl.op.idx_list[0])
if idx== -1:
req_steps = numpy.max([1, req_steps])
else:
req_steps = 0
req_steps = None
break
except:
req_steps = 0
req_steps = None
break
store_steps[i] = req_steps
store_steps[i] = req_steps if req_steps != None else 0
else:
store_steps[i] = op.store_steps[i]
if numpy.any(store_steps!= op.store_steps):
new_scan = Scan((op.inputs, op.outputs, op.givens,
op.slice_to_seqs),op.n_seqs, op.n_outs,
op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps,
op.truncate_gradient, op.n_outs_not_shared, op.inner_fn_start_shared,
op.inner_fn_end_shared, op.go_backwards,
store_steps, op.return_steps, op.mode,
op.truncate_gradient, op.n_outs_not_shared,
op.inner_fn_start_shared, op.inner_fn_end_shared,
op.go_backwards, store_steps, op.return_steps, op.mode,
op.inplace, name = op.fn.name).make_node(*node.inputs)
# we not need to replace the outputs of scan
for i,out in enumerate(node.outputs):
......@@ -1757,8 +1784,8 @@ def scan_make_inplace(node):
return Scan((op.inputs, op.outputs, op.givens, op.slice_to_seqs ) , op.n_seqs,
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, op.n_steps,
op.truncate_gradient, op.n_outs_not_shared, op.inner_fn_start_shared,
op.inner_fn_end_shared, op.go_backwards, op.store_steps, op.return_steps, op.mode,
inplace=True, name = op.fn.name).make_node(*node.inputs).outputs
op.inner_fn_end_shared, op.go_backwards, op.store_steps, op.return_steps,
op.mode, inplace=True, name = op.fn.name).make_node(*node.inputs).outputs
return False
......
......@@ -4015,6 +4015,10 @@ def verify_grad(op, pt, n_tests=2, rng=None, eps=None, tol=None, mode=None, cast
debug mode, which can be very slow if it has to verify a lot
of intermediate computations.
:note: This op does not support multiple outputs. In tests/test_scan.py there is
an experimental verify_grad that covers that case as well by using random
projections.
"""
pt = [numpy.array(p) for p in pt]
......
......@@ -6,7 +6,29 @@ import numpy
import random
import numpy.random
from theano.tests import unittest_tools as utt
'''
Questions and notes about scan that should be answered:
* Even though it does not make it publicly known in
the documentation, scan allows you to set both a return_steps
flag and a store_steps flag (the first one is a soft condition telling
you how many steps to return, the second one determines how much memory to
allocate). There is an optimization as well, that transforms return_steps to
store_steps. Questions:
- what happens if both flags are set?
answer: whatever return_steps says is ignored, and store_steps is used
- the optimization works only with return_steps = -1; can it be made to work
with other values?
answer: 6 Jul 2010 RP: it is a bit hairy to figure out from the subtensors what
exactly you need
* Scan seems to do copies of every input variable. Is that needed?
answer: probably not, but it doesn't hurt either (what we copy is theano variables,
which just carry information about the type / dimension of the data)
* There is some scan functionality that is not well documented
'''
class multiple_outputs_numeric_grad:
......@@ -103,7 +125,7 @@ class multiple_outputs_numeric_grad:
# use it with the normal verify_grad rather than the
# copy-and-pasted one above.
# Also - add a reference to this technique in the
# verify_grad method so that other ops with multiple outputs can be tested.
# verify_grad method so that other ops with multiple outputs can be tested. DONE - rp
def scan_project_sum(*args, **kwargs):
rng = theano.tensor.shared_randomstreams.RandomStreams(123)
scan_outputs, updates = theano.scan(*args, **kwargs)
......@@ -941,6 +963,18 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose([ny2,ny2], nz2)
assert not numpy.allclose(ny1,ny2)
def test_grad_of_shared(self):
    """Regression test for the fix in this commit: take the gradient of a
    scan output w.r.t. a shared variable that is used inside the inner
    function but was NOT passed explicitly to scan (``x1`` is captured by
    the lambda's closure rather than given via ``non_sequences``)."""
    x1 = theano.shared(3.)
    x1.name = 'x1'
    x2 = theano.tensor.vector('x2')
    # y[i] = x2[i] * x1, hence d(sum(y))/d(x1) = sum(x2)
    y, updates = theano.scan(lambda v: v*x1, sequences = x2)
    m = theano.tensor.grad(y.sum(), x1)
    f = theano.function([x2], m)
    # sum([2, 3]) == 5; evaluate once (removed a leftover debug print that
    # also caused the compiled function to be evaluated twice)
    assert numpy.allclose(f([2,3]), 5)
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论