提交 0b042130 authored 作者: Guillaume Desjardins's avatar Guillaume Desjardins

BUG FIX: state_updates should be initialized in MRG_RandomStreams.__init__ to

prevent different random number generators from sharing state.
上级 ae33d372
......@@ -637,10 +637,6 @@ def guess_n_streams(size, warn=True):
class MRG_RandomStreams(object):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)"""
state_updates = []
"""A list of pairs of the form (input_r, output_r), representing the
update rules of all the random states generated by this RandomStreams"""
def updates(self):
return list(self.state_updates)
......@@ -655,6 +651,10 @@ class MRG_RandomStreams(object):
M2 = 2147462579, and not all 0.
"""
# A list of pairs of the form (input_r, output_r), representing the
# update rules of all the random states generated by this RandomStreams"""
self.state_updates = []
super(MRG_RandomStreams, self).__init__()
if isinstance(seed, int):
if seed == 0:
......
......@@ -677,3 +677,23 @@ class T_MRG(unittest.TestCase):
self.assertRaises(ValueError, R.binomial, size)
self.assertRaises(ValueError, R.multinomial, size, 1, [])
self.assertRaises(ValueError, R.normal, size)
def test_multiple_rng():
"""
Test that we can have multiple random number generators in parallel, and
that we can replicate the stream of one with another. This is meant to fix a
previous bug in rng_mrg.MRG_RandomStreams where state_updates was
initialized as a class variable, instead of in the __init__ method. This
would cause all MRG_RandomStreams objects to share the same state.
"""
rng1 = MRG_RandomStreams(1234)
var1 = theano.shared(numpy.ones(1))
out1 = rng1.uniform(size=(1,))
f1 = theano.function([], out1, updates={var1: out1})
assert len(rng1.state_updates) == 1
rng2 = MRG_RandomStreams(1234)
out2 = rng2.uniform(size=(1,))
f2 = theano.function([], out2, updates={var1: out2})
assert len(rng1.state_updates) == 1
assert len(rng2.state_updates) == 1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论