提交 9189850b authored 作者: Guillaume Desjardins's avatar Guillaume Desjardins

Added unit-test for transfering random states between theano graphs

上级 ad55f02a
...@@ -678,7 +678,8 @@ class T_MRG(unittest.TestCase): ...@@ -678,7 +678,8 @@ class T_MRG(unittest.TestCase):
self.assertRaises(ValueError, R.multinomial, size, 1, []) self.assertRaises(ValueError, R.multinomial, size, 1, [])
self.assertRaises(ValueError, R.normal, size) self.assertRaises(ValueError, R.normal, size)
def test_multiple_rng():
def test_multiple_rng_aliasing():
""" """
Test that when we have multiple random number generators, we do not alias Test that when we have multiple random number generators, we do not alias
the state_updates member. `state_updates` can be useful when attempting to the state_updates member. `state_updates` can be useful when attempting to
...@@ -689,3 +690,23 @@ def test_multiple_rng(): ...@@ -689,3 +690,23 @@ def test_multiple_rng():
rng1 = MRG_RandomStreams(1234) rng1 = MRG_RandomStreams(1234)
rng2 = MRG_RandomStreams(2392) rng2 = MRG_RandomStreams(2392)
assert rng1.state_updates is not rng2.state_updates assert rng1.state_updates is not rng2.state_updates
def test_random_state_transfer():
"""
Test that random state can be transferred from one theano graph to another.
"""
class Graph():
def __init__(self, seed=123):
self.rng = MRG_RandomStreams(seed)
self.y = self.rng.uniform(size=(1,))
g1 = Graph(seed=123)
f1 = theano.function([], g1.y)
g2 = Graph(seed=987)
f2 = theano.function([], g2.y)
g2.rng.rstate = g1.rng.rstate
for (su1, su2) in zip(g1.rng.state_updates, g2.rng.state_updates):
su2[0].set_value(su1[0].get_value())
numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
...@@ -715,7 +715,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -715,7 +715,7 @@ class T_SharedRandomStreams(unittest.TestCase):
s_rng.set_value(rr, borrow=True) s_rng.set_value(rr, borrow=True)
assert rr is s_rng.container.storage[0] assert rr is s_rng.container.storage[0]
def test_multiple_rng(self): def test_multiple_rng_aliasing(self):
""" """
Test that when we have multiple random number generators, we do not alias Test that when we have multiple random number generators, we do not alias
the state_updates member. `state_updates` can be useful when attempting to the state_updates member. `state_updates` can be useful when attempting to
...@@ -728,6 +728,24 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -728,6 +728,24 @@ class T_SharedRandomStreams(unittest.TestCase):
assert rng1.state_updates is not rng2.state_updates assert rng1.state_updates is not rng2.state_updates
assert rng1.gen_seedgen is not rng2.gen_seedgen assert rng1.gen_seedgen is not rng2.gen_seedgen
def test_random_state_transfer(self):
"""
Test that random state can be transferred from one theano graph to another.
"""
class Graph():
def __init__(self, seed=123):
self.rng = RandomStreams(seed)
self.y = self.rng.uniform(size=(1,))
g1 = Graph(seed=123)
f1 = function([], g1.y)
g2 = Graph(seed=987)
f2 = function([], g2.y)
for (su1, su2) in zip(g1.rng.state_updates, g2.rng.state_updates):
su2[0].set_value(su1[0].get_value())
numpy.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论