提交 b0cfc18d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

merge

...@@ -150,18 +150,18 @@ class T_RandomStreams(unittest.TestCase): ...@@ -150,18 +150,18 @@ class T_RandomStreams(unittest.TestCase):
def test_ndim(self): def test_ndim(self):
"""Test that the behaviour of 'ndim' optional parameter""" """Test that the behaviour of 'ndim' optional parameter"""
# 'ndim' is an optional integer parameter, specifying the length # 'ndim' is an optional integer parameter, specifying the length
# of the 'shape', placed as first argument. # of the 'shape', passed as a keyword argument.
# ndim not specified, OK # ndim not specified, OK
m1 = Module() m1 = Module()
m1.random = RandomStreams(234) m1.random = RandomStreams(utt.fetch_seed())
m1.fn = Method([], m1.random.uniform((2,2))) m1.fn = Method([], m1.random.uniform((2,2)))
made1 = m1.make() made1 = m1.make()
made1.random.initialize() made1.random.initialize()
# ndim specified, consistent with shape, OK # ndim specified, consistent with shape, OK
m2 = Module() m2 = Module()
m2.random = RandomStreams(234) m2.random = RandomStreams(utt.fetch_seed())
m2.fn = Method([], m2.random.uniform((2,2), ndim=2)) m2.fn = Method([], m2.random.uniform((2,2), ndim=2))
made2 = m2.make() made2 = m2.make()
made2.random.initialize() made2.random.initialize()
...@@ -172,7 +172,7 @@ class T_RandomStreams(unittest.TestCase): ...@@ -172,7 +172,7 @@ class T_RandomStreams(unittest.TestCase):
# ndim specified, inconsistent with shape, should raise ValueError # ndim specified, inconsistent with shape, should raise ValueError
m3 = Module() m3 = Module()
m3.random = RandomStreams(234) m3.random = RandomStreams(utt.fetch_seed())
self.assertRaises(ValueError, m3.random.uniform, (2,2), ndim=1) self.assertRaises(ValueError, m3.random.uniform, (2,2), ndim=1)
def test_uniform(self): def test_uniform(self):
...@@ -283,7 +283,7 @@ class T_RandomStreams(unittest.TestCase): ...@@ -283,7 +283,7 @@ class T_RandomStreams(unittest.TestCase):
assert numpy.all(fn_val1 == numpy_val1) assert numpy.all(fn_val1 == numpy_val1)
def test_shuffle_row_elements(self): def test_shuffle_row_elements(self):
"""Ensure RandomStreams.shuffle_row_elements generates right results""" """Test that RandomStreams.shuffle_row_elements generates the right results"""
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
# On matrices, for each row, the elements of that row should be # On matrices, for each row, the elements of that row should be
# shuffled. # shuffled.
...@@ -475,7 +475,6 @@ class T_RandomStreams(unittest.TestCase): ...@@ -475,7 +475,6 @@ class T_RandomStreams(unittest.TestCase):
low_val = [.1, .2, .3] low_val = [.1, .2, .3]
high_val = [1.1, 2.2, 3.3] high_val = [1.1, 2.2, 3.3]
seed_gen = numpy.random.RandomState(utt.fetch_seed()) seed_gen = numpy.random.RandomState(utt.fetch_seed())
numpy_rng = numpy.random.RandomState(int(seed_gen.randint(2**30))) numpy_rng = numpy.random.RandomState(int(seed_gen.randint(2**30)))
...@@ -512,7 +511,6 @@ class T_RandomStreams(unittest.TestCase): ...@@ -512,7 +511,6 @@ class T_RandomStreams(unittest.TestCase):
n_val = [1, 2, 3] n_val = [1, 2, 3]
prob_val = [.1, .2, .3] prob_val = [.1, .2, .3]
seed_gen = numpy.random.RandomState(utt.fetch_seed()) seed_gen = numpy.random.RandomState(utt.fetch_seed())
numpy_rng = numpy.random.RandomState(int(seed_gen.randint(2**30))) numpy_rng = numpy.random.RandomState(int(seed_gen.randint(2**30)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论