提交 4e010efd authored 作者: Razvan Pascanu's avatar Razvan Pascanu

reverted the many copies scan used to do since it introduced bugs and the…

Reverted the many copies that scan used to make, since they introduced bugs and the corner case they were supposed to address was fixed by Pascal L.
上级 a21a33fd
...@@ -377,7 +377,6 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -377,7 +377,6 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
seqs = [sequences] seqs = [sequences]
else: else:
seqs = sequences seqs = sequences
if not (type(outputs_info) in (list,tuple)): if not (type(outputs_info) in (list,tuple)):
outs_info = [outputs_info] outs_info = [outputs_info]
else: else:
...@@ -633,29 +632,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -633,29 +632,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# remove shared variables from the non sequences list # remove shared variables from the non sequences list
# such that we can compile the function ( the user has the option to add them when # such that we can compile the function ( the user has the option to add them when
# writing scan, because in some situations this might make the code more readable) # writing scan, because in some situations this might make the code more readable)
# Also duplicate the list of non sequences arguments to contain copies of the
# non-shared inputs ( this fixes the case when one of this inputs has a default
# update attached to it that belongs to some shared random stream )
#
# Note : In that case, scan assumes that you do not want to draw new numbers at
# every call ( you would have made the internal function do that explicitly
# if you wanted to) but rather to use that initial draw as a matrix of values
new_non_seqs = []
notshared_other_args = [] notshared_other_args = []
notshared_other_args_copies = []
for non_seq in non_seqs: for non_seq in non_seqs:
if not isinstance(non_seq, SharedVariable): if not isinstance(non_seq, SharedVariable):
if n_fixed_steps not in [-1,1]:
non_seq_copy = non_seq.type()
if non_seq.name :
non_seq_copy.name = non_seq.name + '_copy'
else:
non_seq_copy = non_seq
notshared_other_args += [non_seq] notshared_other_args += [non_seq]
notshared_other_args_copies += [non_seq_copy]
new_non_seqs += [non_seq_copy]
else:
new_non_seqs += [non_seq]
# add only the not shared variables to the arguments of the dummy # add only the not shared variables to the arguments of the dummy
# function [ a function should not get shared variables as input ] # function [ a function should not get shared variables as input ]
...@@ -663,10 +643,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -663,10 +643,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
for arg in args: for arg in args:
if not isinstance(arg, SharedVariable): if not isinstance(arg, SharedVariable):
dummy_args += [arg] dummy_args += [arg]
dummy_args += notshared_other_args_copies dummy_args += notshared_other_args
# arguments for the lambda expression that gives us the output # arguments for the lambda expression that gives us the output
# of the inner function # of the inner function
args += new_non_seqs args += non_seqs
# when we apply the lambda expression we get a mixture of update rules # when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated # and outputs that needs to be separated
...@@ -776,9 +756,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -776,9 +756,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# or defult behaviour ( like always add the extra outputs at the end !?) # or defult behaviour ( like always add the extra outputs at the end !?)
# But I did not bother implementing this, I leave it to the user to clearly # But I did not bother implementing this, I leave it to the user to clearly
# express what he/she wants to do # express what he/she wants to do
raise ValueError('There has been a terrible mistake in our input arguments' raise ValueError('Scan is totally lost. Make sure that you indicate for each'
' and scan is totally lost. Make sure that you indicate for every ' ' output what taps you want to use, or None, if you do not want to'
' output what taps you want to use, or None, if you do not want to '
' use any !') ' use any !')
inner_fn_inputs=[input.variable for input in \ inner_fn_inputs=[input.variable for input in \
dummy_f.maker.expanded_inputs[:dummy_notshared_ins+dummy_notshared_init_outs]] dummy_f.maker.expanded_inputs[:dummy_notshared_ins+dummy_notshared_init_outs]]
......
...@@ -975,6 +975,34 @@ class T_Scan(unittest.TestCase): ...@@ -975,6 +975,34 @@ class T_Scan(unittest.TestCase):
print f([2,3]) print f([2,3])
assert numpy.allclose(f([2,3]) , 5) assert numpy.allclose(f([2,3]) , 5)
def test_computing_gradient(self):
x1 = theano.tensor.scalar()
x2 = theano.shared(numpy.array([1,2,3,4,5]))
K = x2*x1
out,updates = theano.scan(lambda i,v: theano.tensor.grad(K[i], v),
sequences = theano.tensor.arange(K.shape[0]), non_sequences=x1)
f = theano.function([x1], out)
print f(3.)
assert numpy.all( f(3.) != 0. )
'''
def test_shared_updates(self):
X = theano.shared( numpy.array( [[1,2,3],[4,5,6]]))
out,updates = theano.scan( lambda :{X: X+1}, outputs_info = [], non_sequences= [],
sequences = [], n_steps = 10)
f = theano.function([],[], updates = updates)
f()
print X.value
'''
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论