提交 3c268cc9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

scan now makes copies of the other arguments

上级 4f88815c
...@@ -629,10 +629,24 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -629,10 +629,24 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# remove shared variables from the non sequences list # remove shared variables from the non sequences list
# such that we can compile the function ( the user has the option to add them when writing # such that we can compile the function ( the user has the option to add them when writing
# scan, because in some situations this might make the code more readable) # scan, because in some situations this might make the code more readable)
# Also duplicate the list of non sequences arguments to contain copies of the non-shared
# inputs ( this fixes the case when one of this inputs has a default update attached to it
# that belongs to some shared random stream ).
#
# Note : In that case, scan assumes that you do not want to draw new numbers at every call ( you
# would have made the internal function do that explicitly if you wanted to) but rather to
# use that initial draw as a matrix of values
new_non_seqs = []
notshared_other_args = [] notshared_other_args = []
notshared_other_args_copies = []
for non_seq in non_seqs: for non_seq in non_seqs:
if not isinstance(non_seq, SharedVariable): if not isinstance(non_seq, SharedVariable):
non_seq_copy = non_seq.type()
notshared_other_args += [non_seq] notshared_other_args += [non_seq]
notshared_other_args_copies += [non_seq_copy]
new_non_seqs += [non_seq_copy]
else:
new_non_seqs += [non_seq]
# add only the not shared variables to the arguments of the dummy # add only the not shared variables to the arguments of the dummy
# function [ a function should not get shared variables as input ] # function [ a function should not get shared variables as input ]
...@@ -640,10 +654,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -640,10 +654,10 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
for arg in args: for arg in args:
if not isinstance(arg, SharedVariable): if not isinstance(arg, SharedVariable):
dummy_args += [arg] dummy_args += [arg]
dummy_args += notshared_other_args dummy_args += notshared_other_args_copies
# arguments for the lambda expression that gives us the output # arguments for the lambda expression that gives us the output
# of the inner function # of the inner function
args += non_seqs args += new_non_seqs
# when we apply the lambda expression we get a mixture of update rules # when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated # and outputs that needs to be separated
...@@ -726,6 +740,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -726,6 +740,7 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
not isinstance(x,SharedVariable) and not isinstance(x,gof.Constant), \ not isinstance(x,SharedVariable) and not isinstance(x,gof.Constant), \
gof.graph.inputs(dummy_args)), outputs, updates = updates, mode = compile.mode.Mode(linker='py',optimizer=None)) gof.graph.inputs(dummy_args)), outputs, updates = updates, mode = compile.mode.Mode(linker='py',optimizer=None))
else: else:
print [printing.pp(x) for x in dummy_args]
dummy_f = function(filter(lambda x: isinstance(x, gof.Variable) and \ dummy_f = function(filter(lambda x: isinstance(x, gof.Variable) and \
not isinstance(x,SharedVariable) and not isinstance(x,gof.Constant), \ not isinstance(x,SharedVariable) and not isinstance(x,gof.Constant), \
dummy_args), outputs, updates = updates, mode = compile.mode.Mode(linker='py',optimizer=None)) dummy_args), outputs, updates = updates, mode = compile.mode.Mode(linker='py',optimizer=None))
......
...@@ -922,7 +922,24 @@ class T_Scan(unittest.TestCase): ...@@ -922,7 +922,24 @@ class T_Scan(unittest.TestCase):
assert len(analytic_grad[0]) == 3 assert len(analytic_grad[0]) == 3
def test_draw_as_input_to_scan(self):
trng = theano.tensor.shared_randomstreams.RandomStreams(123)
x = theano.tensor.matrix('x')
y = trng.binomial(size = x.shape, p = x)
z,updates = theano.scan(lambda a:a, non_sequences=y, n_steps=2)
f = theano.function([x],[y,z], updates = updates)
rng = numpy.random.RandomState(utt.fetch_seed())
nx = rng.uniform( size = (10,10) )
ny1,nz1 = f(nx)
ny2,nz2 = f(nx)
assert numpy.allclose([ny1,ny1], nz1)
assert numpy.allclose([ny2,ny2], nz2)
assert not numpy.allclose(ny1,ny2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论