提交 8bdd7392 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge

......@@ -377,7 +377,6 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
seqs = [sequences]
else:
seqs = sequences
if not (type(outputs_info) in (list,tuple)):
outs_info = [outputs_info]
else:
......@@ -633,29 +632,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# remove shared variables from the non sequences list
# such that we can compile the function ( the user has the option to add them when
# writing scan, because in some situations this might make the code more readable)
# Also duplicate the list of non sequences arguments to contain copies of the
# non-shared inputs ( this fixes the case when one of this inputs has a default
# update attached to it that belongs to some shared random stream )
#
# Note : In that case, scan assumes that you do not want to draw new numbers at
# every call ( you would have made the internal function do that explicitly
# if you wanted to) but rather to use that initial draw as a matrix of values
new_non_seqs = []
notshared_other_args = []
notshared_other_args_copies = []
for non_seq in non_seqs:
if not isinstance(non_seq, SharedVariable):
if n_fixed_steps not in [-1,1]:
non_seq_copy = non_seq.type()
if non_seq.name :
non_seq_copy.name = non_seq.name + '_copy'
else:
non_seq_copy = non_seq
notshared_other_args += [non_seq]
notshared_other_args_copies += [non_seq_copy]
new_non_seqs += [non_seq_copy]
else:
new_non_seqs += [non_seq]
# add only the not shared variables to the arguments of the dummy
# function [ a function should not get shared variables as input ]
......@@ -663,10 +643,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
for arg in args:
if not isinstance(arg, SharedVariable):
dummy_args += [arg]
dummy_args += notshared_other_args_copies
dummy_args += notshared_other_args
# arguments for the lambda expression that gives us the output
# of the inner function
args += new_non_seqs
args += non_seqs
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
......@@ -776,9 +756,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# or defult behaviour ( like always add the extra outputs at the end !?)
# But I did not bother implementing this, I leave it to the user to clearly
# express what he/she wants to do
raise ValueError('There has been a terrible mistake in our input arguments'
' and scan is totally lost. Make sure that you indicate for every '
' output what taps you want to use, or None, if you do not want to '
raise ValueError('Scan is totally lost. Make sure that you indicate for each'
' output what taps you want to use, or None, if you do not want to'
' use any !')
inner_fn_inputs=[input.variable for input in \
dummy_f.maker.expanded_inputs[:dummy_notshared_ins+dummy_notshared_init_outs]]
......
......@@ -975,6 +975,34 @@ class T_Scan(unittest.TestCase):
print f([2,3])
assert numpy.allclose(f([2,3]) , 5)
def test_computing_gradient(self):
x1 = theano.tensor.scalar()
x2 = theano.shared(numpy.array([1,2,3,4,5]))
K = x2*x1
out,updates = theano.scan(lambda i,v: theano.tensor.grad(K[i], v),
sequences = theano.tensor.arange(K.shape[0]), non_sequences=x1)
f = theano.function([x1], out)
print f(3.)
assert numpy.all( f(3.) != 0. )
'''
def test_shared_updates(self):
X = theano.shared( numpy.array( [[1,2,3],[4,5,6]]))
out,updates = theano.scan( lambda :{X: X+1}, outputs_info = [], non_sequences= [],
sequences = [], n_steps = 10)
f = theano.function([],[], updates = updates)
f()
print X.value
'''
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论