提交 2839b9db authored 作者: Guillaume Desjardins's avatar Guillaume Desjardins

Removed all class-member initialization from MRG_RandomStreams, RandomStreams

and SharedRandomStreams (and added associated unit-tests)
上级 a963ce9c
......@@ -680,20 +680,12 @@ class T_MRG(unittest.TestCase):
def test_multiple_rng():
"""
Test that we can have multiple random number generators in parallel, and
that we can replicate the stream of one with another. This is meant to fix a
previous bug in rng_mrg.MRG_RandomStreams where state_updates was
initialized as a class variable, instead of in the __init__ method. This
would cause all MRG_RandomStreams objects to share the same state.
Test that when we have multiple random number generators, we do not alias
the state_updates member. `state_updates` can be useful when attempting to
copy the (random) state between two similar theano graphs. The test is
meant to detect a previous bug where state_updates was initialized as a
class-attribute, instead of the __init__ function.
"""
rng1 = MRG_RandomStreams(1234)
var1 = theano.shared(numpy.ones(1))
out1 = rng1.uniform(size=(1,))
f1 = theano.function([], out1, updates={var1: out1})
assert len(rng1.state_updates) == 1
rng2 = MRG_RandomStreams(1234)
out2 = rng2.uniform(size=(1,))
f2 = theano.function([], out2, updates={var1: out2})
assert len(rng1.state_updates) == 1
assert len(rng2.state_updates) == 1
rng2 = MRG_RandomStreams(2392)
assert rng1.state_updates is not rng2.state_updates
......@@ -122,20 +122,6 @@ class RandomStreams(Component, raw_random.RandomStreamsBase):
"""
random_state_variables = []
"""A list of pairs of the form (input_r, output_r). This will be
over-ridden by the module instance to contain stream
generators.
"""
default_instance_seed = None
"""Instance variable should take None or integer value. Used to
seed the random number generator that provides seeds for member
streams
"""
def __init__(self, seed=None, no_warn=False):
""":type seed: None or int
......@@ -147,7 +133,13 @@ class RandomStreams(Component, raw_random.RandomStreamsBase):
if not no_warn:
deprecation_warning()
super(RandomStreams, self).__init__(no_warn=True)
# A list of pairs of the form (input_r, output_r). This will be
# over-ridden by the module instance to contain stream generators.
self.random_state_variables = []
# Instance variable should take None or integer value. Used to seed the
# random number generator that provides seeds for member streams
self.default_instance_seed = seed
def allocate(self, memo):
......
......@@ -682,6 +682,17 @@ class T_RandomStreams(unittest.TestCase):
assert val1.dtype == 'int8'
assert numpy.all(abs(val1) <= 1)
def test_multiple_rng(self):
"""
Test that when we have multiple random number generators, we do not alias
the state_updates member. `state_updates` can be useful when attempting to
copy the (random) state between two similar theano graphs. The test is
meant to detect a previous bug where state_updates was initialized as a
class-attribute, instead of the __init__ function.
"""
rng1 = RandomStreams(1234)
rng2 = RandomStreams(2392)
assert rng1.random_state_variables is not rng2.random_state_variables
if __name__ == '__main__':
from theano.tests import main
......
......@@ -715,7 +715,18 @@ class T_SharedRandomStreams(unittest.TestCase):
s_rng.set_value(rr, borrow=True)
assert rr is s_rng.container.storage[0]
def test_multiple_rng(self):
"""
Test that when we have multiple random number generators, we do not alias
the state_updates member. `state_updates` can be useful when attempting to
copy the (random) state between two similar theano graphs. The test is
meant to detect a previous bug where state_updates was initialized as a
class-attribute, instead of the __init__ function.
"""
rng1 = RandomStreams(1234)
rng2 = RandomStreams(2392)
assert rng1.state_updates is not rng2.state_updates
assert rng1.gen_seedgen is not rng2.gen_seedgen
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论