提交 767a7312 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use utt.fetch_seed() instead of a fixed seed in test_sharedrandomstreams

上级 b90ec1d2
...@@ -2,7 +2,7 @@ __docformat__ = "restructuredtext en" ...@@ -2,7 +2,7 @@ __docformat__ = "restructuredtext en"
import sys import sys
import unittest import unittest
import numpy import numpy
from theano.tensor import raw_random from theano.tensor import raw_random
from theano.tensor.shared_randomstreams import RandomStreams from theano.tensor.shared_randomstreams import RandomStreams
...@@ -11,7 +11,7 @@ from theano import function ...@@ -11,7 +11,7 @@ from theano import function
from theano import tensor from theano import tensor
from theano import compile, gof from theano import compile, gof
from theano.tests import unittest_tools from theano.tests import unittest_tools as utt
class T_SharedRandomStreams(unittest.TestCase): class T_SharedRandomStreams(unittest.TestCase):
...@@ -30,7 +30,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -30,7 +30,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert isinstance(rv_u.rng.value, numpy.random.RandomState) assert isinstance(rv_u.rng.value, numpy.random.RandomState)
def test_basics(self): def test_basics(self):
random = RandomStreams(234) random = RandomStreams(utt.fetch_seed())
fn = function([], random.uniform((2,2)), updates=random.updates()) fn = function([], random.uniform((2,2)), updates=random.updates())
gn = function([], random.normal((2,2)), updates=random.updates()) gn = function([], random.normal((2,2)), updates=random.updates())
...@@ -39,7 +39,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -39,7 +39,7 @@ class T_SharedRandomStreams(unittest.TestCase):
gn_val0 = gn() gn_val0 = gn()
rng_seed = numpy.random.RandomState(234).randint(2**30) rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0 #print fn_val0
...@@ -58,12 +58,12 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -58,12 +58,12 @@ class T_SharedRandomStreams(unittest.TestCase):
random = RandomStreams(234) random = RandomStreams(234)
fn = function([], random.uniform((2,2)), updates=random.updates()) fn = function([], random.uniform((2,2)), updates=random.updates())
random.seed(888) random.seed(utt.fetch_seed())
fn_val0 = fn() fn_val0 = fn()
fn_val1 = fn() fn_val1 = fn()
rng_seed = numpy.random.RandomState(888).randint(2**30) rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0 #print fn_val0
...@@ -80,7 +80,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -80,7 +80,7 @@ class T_SharedRandomStreams(unittest.TestCase):
out = random.uniform((2,2)) out = random.uniform((2,2))
fn = function([], out, updates=random.updates()) fn = function([], out, updates=random.updates())
random.seed(888) random.seed(utt.fetch_seed())
rng = numpy.random.RandomState() rng = numpy.random.RandomState()
rng.set_state(random[out.rng].get_state()) #tests getitem rng.set_state(random[out.rng].get_state()) #tests getitem
...@@ -100,8 +100,8 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -100,8 +100,8 @@ class T_SharedRandomStreams(unittest.TestCase):
random.seed(888) random.seed(888)
rng = numpy.random.RandomState(823874) rng = numpy.random.RandomState(utt.fetch_seed())
random[out.rng] = numpy.random.RandomState(823874) random[out.rng] = numpy.random.RandomState(utt.fetch_seed())
fn_val0 = fn() fn_val0 = fn()
fn_val1 = fn() fn_val1 = fn()
...@@ -111,15 +111,15 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -111,15 +111,15 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_val1 == numpy_val1) assert numpy.all(fn_val1 == numpy_val1)
def test_permutation(self): def test_permutation(self):
"""Test that RandomStreams.uniform generates the same results as numpy""" """Test that RandomStreams.permutation generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
random = RandomStreams(234) random = RandomStreams(utt.fetch_seed())
fn = function([], random.permutation((20,), 10), updates=random.updates()) fn = function([], random.permutation((20,), 10), updates=random.updates())
fn_val0 = fn() fn_val0 = fn()
fn_val1 = fn() fn_val1 = fn()
rng_seed = numpy.random.RandomState(234).randint(2**30) rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
# rng.permutation outputs one vector at a time, so we iterate. # rng.permutation outputs one vector at a time, so we iterate.
...@@ -132,13 +132,13 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -132,13 +132,13 @@ class T_SharedRandomStreams(unittest.TestCase):
def test_multinomial(self): def test_multinomial(self):
"""Test that RandomStreams.multinomial generates the same results as numpy""" """Test that RandomStreams.multinomial generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
random = RandomStreams(234) random = RandomStreams(utt.fetch_seed())
fn = function([], random.multinomial((4,4), 1, [0.1]*10), updates=random.updates()) fn = function([], random.multinomial((4,4), 1, [0.1]*10), updates=random.updates())
fn_val0 = fn() fn_val0 = fn()
fn_val1 = fn() fn_val1 = fn()
rng_seed = numpy.random.RandomState(234).randint(2**30) rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
numpy_val0 = rng.multinomial(1, [0.1]*10, size=(4,4)) numpy_val0 = rng.multinomial(1, [0.1]*10, size=(4,4))
numpy_val1 = rng.multinomial(1, [0.1]*10, size=(4,4)) numpy_val1 = rng.multinomial(1, [0.1]*10, size=(4,4))
...@@ -153,11 +153,12 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -153,11 +153,12 @@ class T_SharedRandomStreams(unittest.TestCase):
# On matrices, for each row, the elements of that row should be shuffled. # On matrices, for each row, the elements of that row should be shuffled.
# Note that this differs from numpy.random.shuffle, where all the elements # Note that this differs from numpy.random.shuffle, where all the elements
# of the matrix are shuffled. # of the matrix are shuffled.
random = RandomStreams(234) random = RandomStreams(utt.fetch_seed())
m_input = tensor.dmatrix() m_input = tensor.dmatrix()
f = function([m_input], random.shuffle_row_elements(m_input), updates=random.updates()) f = function([m_input], random.shuffle_row_elements(m_input), updates=random.updates())
val_rng = numpy.random.RandomState(unittest_tools.fetch_seed()) # Generate the elements to be shuffled
val_rng = numpy.random.RandomState(utt.fetch_seed()+42)
in_mval = val_rng.uniform(-2, 2, size=(20,5)) in_mval = val_rng.uniform(-2, 2, size=(20,5))
fn_mval0 = f(in_mval) fn_mval0 = f(in_mval)
fn_mval1 = f(in_mval) fn_mval1 = f(in_mval)
...@@ -168,7 +169,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -168,7 +169,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert not numpy.all(in_mval == fn_mval1) assert not numpy.all(in_mval == fn_mval1)
assert not numpy.all(fn_mval0 == fn_mval1) assert not numpy.all(fn_mval0 == fn_mval1)
rng_seed = numpy.random.RandomState(234).randint(2**30) rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) rng = numpy.random.RandomState(int(rng_seed))
numpy_mval0 = in_mval.copy() numpy_mval0 = in_mval.copy()
numpy_mval1 = in_mval.copy() numpy_mval1 = in_mval.copy()
...@@ -182,7 +183,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -182,7 +183,7 @@ class T_SharedRandomStreams(unittest.TestCase):
# On vectors, the behaviour is the same as numpy.random.shuffle, # On vectors, the behaviour is the same as numpy.random.shuffle,
# except that it does not work in place, but returns a shuffled vector. # except that it does not work in place, but returns a shuffled vector.
random1 = RandomStreams(234) random1 = RandomStreams(utt.fetch_seed())
v_input = tensor.dvector() v_input = tensor.dvector()
f1 = function([v_input], random1.shuffle_row_elements(v_input)) f1 = function([v_input], random1.shuffle_row_elements(v_input))
...@@ -203,7 +204,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -203,7 +204,7 @@ class T_SharedRandomStreams(unittest.TestCase):
def test_default_updates(self): def test_default_updates(self):
# Basic case: default_updates # Basic case: default_updates
random_a = RandomStreams(234) random_a = RandomStreams(utt.fetch_seed())
out_a = random_a.uniform((2,2)) out_a = random_a.uniform((2,2))
fn_a = function([], out_a) fn_a = function([], out_a)
fn_a_val0 = fn_a() fn_a_val0 = fn_a()
...@@ -214,7 +215,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -214,7 +215,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(abs(nearly_zeros()) < 1e-5) assert numpy.all(abs(nearly_zeros()) < 1e-5)
# Explicit updates #1 # Explicit updates #1
random_b = RandomStreams(234) random_b = RandomStreams(utt.fetch_seed())
out_b = random_b.uniform((2,2)) out_b = random_b.uniform((2,2))
fn_b = function([], out_b, updates=random_b.updates()) fn_b = function([], out_b, updates=random_b.updates())
fn_b_val0 = fn_b() fn_b_val0 = fn_b()
...@@ -223,7 +224,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -223,7 +224,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_b_val1 == fn_a_val1) assert numpy.all(fn_b_val1 == fn_a_val1)
# Explicit updates #2 # Explicit updates #2
random_c = RandomStreams(234) random_c = RandomStreams(utt.fetch_seed())
out_c = random_c.uniform((2,2)) out_c = random_c.uniform((2,2))
fn_c = function([], out_c, updates=[out_c.update]) fn_c = function([], out_c, updates=[out_c.update])
fn_c_val0 = fn_c() fn_c_val0 = fn_c()
...@@ -232,7 +233,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -232,7 +233,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_c_val1 == fn_a_val1) assert numpy.all(fn_c_val1 == fn_a_val1)
# No updates at all # No updates at all
random_d = RandomStreams(234) random_d = RandomStreams(utt.fetch_seed())
out_d = random_d.uniform((2,2)) out_d = random_d.uniform((2,2))
fn_d = function([], out_d, no_default_updates=True) fn_d = function([], out_d, no_default_updates=True)
fn_d_val0 = fn_d() fn_d_val0 = fn_d()
...@@ -241,7 +242,7 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -241,7 +242,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_d_val1 == fn_d_val0) assert numpy.all(fn_d_val1 == fn_d_val0)
# No updates for out # No updates for out
random_e = RandomStreams(234) random_e = RandomStreams(utt.fetch_seed())
out_e = random_e.uniform((2,2)) out_e = random_e.uniform((2,2))
fn_e = function([], out_e, no_default_updates=[out_e.rng]) fn_e = function([], out_e, no_default_updates=[out_e.rng])
fn_e_val0 = fn_e() fn_e_val0 = fn_e()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论