提交 2839b9db authored 作者: Guillaume Desjardins's avatar Guillaume Desjardins

Removed all class-member initialization from MRG_RandomStreams, RandomStreams

and SharedRandomStreams (and added associated unit-tests)
上级 a963ce9c
...@@ -680,20 +680,12 @@ class T_MRG(unittest.TestCase): ...@@ -680,20 +680,12 @@ class T_MRG(unittest.TestCase):
def test_multiple_rng(): def test_multiple_rng():
""" """
Test that we can have multiple random number generators in parallel, and Test that when we have multiple random number generators, we do not alias
that we can replicate the stream of one with another. This is meant to fix a the state_updates member. `state_updates` can be useful when attempting to
previous bug in rng_mrg.MRG_RandomStreams where state_updates was copy the (random) state between two similar theano graphs. The test is
initialized as a class variable, instead of in the __init__ method. This meant to detect a previous bug where state_updates was initialized as a
would cause all MRG_RandomStreams objects to share the same state. class-attribute, instead of the __init__ function.
""" """
rng1 = MRG_RandomStreams(1234) rng1 = MRG_RandomStreams(1234)
var1 = theano.shared(numpy.ones(1)) rng2 = MRG_RandomStreams(2392)
out1 = rng1.uniform(size=(1,)) assert rng1.state_updates is not rng2.state_updates
f1 = theano.function([], out1, updates={var1: out1})
assert len(rng1.state_updates) == 1
rng2 = MRG_RandomStreams(1234)
out2 = rng2.uniform(size=(1,))
f2 = theano.function([], out2, updates={var1: out2})
assert len(rng1.state_updates) == 1
assert len(rng2.state_updates) == 1
...@@ -122,20 +122,6 @@ class RandomStreams(Component, raw_random.RandomStreamsBase): ...@@ -122,20 +122,6 @@ class RandomStreams(Component, raw_random.RandomStreamsBase):
""" """
random_state_variables = []
"""A list of pairs of the form (input_r, output_r). This will be
over-ridden by the module instance to contain stream
generators.
"""
default_instance_seed = None
"""Instance variable should take None or integer value. Used to
seed the random number generator that provides seeds for member
streams
"""
def __init__(self, seed=None, no_warn=False): def __init__(self, seed=None, no_warn=False):
""":type seed: None or int """:type seed: None or int
...@@ -147,7 +133,13 @@ class RandomStreams(Component, raw_random.RandomStreamsBase): ...@@ -147,7 +133,13 @@ class RandomStreams(Component, raw_random.RandomStreamsBase):
if not no_warn: if not no_warn:
deprecation_warning() deprecation_warning()
super(RandomStreams, self).__init__(no_warn=True) super(RandomStreams, self).__init__(no_warn=True)
# A list of pairs of the form (input_r, output_r). This will be
# over-ridden by the module instance to contain stream generators.
self.random_state_variables = [] self.random_state_variables = []
# Instance variable should take None or integer value. Used to seed the
# random number generator that provides seeds for member streams
self.default_instance_seed = seed self.default_instance_seed = seed
def allocate(self, memo): def allocate(self, memo):
......
...@@ -682,6 +682,17 @@ class T_RandomStreams(unittest.TestCase): ...@@ -682,6 +682,17 @@ class T_RandomStreams(unittest.TestCase):
assert val1.dtype == 'int8' assert val1.dtype == 'int8'
assert numpy.all(abs(val1) <= 1) assert numpy.all(abs(val1) <= 1)
def test_multiple_rng(self):
"""
Test that when we have multiple random number generators, we do not alias
the state_updates member. `state_updates` can be useful when attempting to
copy the (random) state between two similar theano graphs. The test is
meant to detect a previous bug where state_updates was initialized as a
class-attribute, instead of the __init__ function.
"""
rng1 = RandomStreams(1234)
rng2 = RandomStreams(2392)
assert rng1.random_state_variables is not rng2.random_state_variables
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
......
...@@ -715,7 +715,18 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -715,7 +715,18 @@ class T_SharedRandomStreams(unittest.TestCase):
s_rng.set_value(rr, borrow=True) s_rng.set_value(rr, borrow=True)
assert rr is s_rng.container.storage[0] assert rr is s_rng.container.storage[0]
def test_multiple_rng(self):
"""
Test that when we have multiple random number generators, we do not alias
the state_updates member. `state_updates` can be useful when attempting to
copy the (random) state between two similar theano graphs. The test is
meant to detect a previous bug where state_updates was initialized as a
class-attribute, instead of the __init__ function.
"""
rng1 = RandomStreams(1234)
rng2 = RandomStreams(2392)
assert rng1.state_updates is not rng2.state_updates
assert rng1.gen_seedgen is not rng2.gen_seedgen
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论