提交 767a7312 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use utt.fetch_seed() instead of a fixed seed in test_sharedrandomstreams

上级 b90ec1d2
......@@ -2,7 +2,7 @@ __docformat__ = "restructuredtext en"
import sys
import unittest
import numpy
import numpy
from theano.tensor import raw_random
from theano.tensor.shared_randomstreams import RandomStreams
......@@ -11,7 +11,7 @@ from theano import function
from theano import tensor
from theano import compile, gof
from theano.tests import unittest_tools
from theano.tests import unittest_tools as utt
class T_SharedRandomStreams(unittest.TestCase):
......@@ -30,7 +30,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert isinstance(rv_u.rng.value, numpy.random.RandomState)
def test_basics(self):
random = RandomStreams(234)
random = RandomStreams(utt.fetch_seed())
fn = function([], random.uniform((2,2)), updates=random.updates())
gn = function([], random.normal((2,2)), updates=random.updates())
......@@ -39,7 +39,7 @@ class T_SharedRandomStreams(unittest.TestCase):
gn_val0 = gn()
rng_seed = numpy.random.RandomState(234).randint(2**30)
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0
......@@ -58,12 +58,12 @@ class T_SharedRandomStreams(unittest.TestCase):
random = RandomStreams(234)
fn = function([], random.uniform((2,2)), updates=random.updates())
random.seed(888)
random.seed(utt.fetch_seed())
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(888).randint(2**30)
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0
......@@ -80,7 +80,7 @@ class T_SharedRandomStreams(unittest.TestCase):
out = random.uniform((2,2))
fn = function([], out, updates=random.updates())
random.seed(888)
random.seed(utt.fetch_seed())
rng = numpy.random.RandomState()
rng.set_state(random[out.rng].get_state()) #tests getitem
......@@ -100,8 +100,8 @@ class T_SharedRandomStreams(unittest.TestCase):
random.seed(888)
rng = numpy.random.RandomState(823874)
random[out.rng] = numpy.random.RandomState(823874)
rng = numpy.random.RandomState(utt.fetch_seed())
random[out.rng] = numpy.random.RandomState(utt.fetch_seed())
fn_val0 = fn()
fn_val1 = fn()
......@@ -111,15 +111,15 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_val1 == numpy_val1)
def test_permutation(self):
"""Test that RandomStreams.uniform generates the same results as numpy"""
"""Test that RandomStreams.permutation generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(234)
random = RandomStreams(utt.fetch_seed())
fn = function([], random.permutation((20,), 10), updates=random.updates())
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(234).randint(2**30)
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
# rng.permutation outputs one vector at a time, so we iterate.
......@@ -132,13 +132,13 @@ class T_SharedRandomStreams(unittest.TestCase):
def test_multinomial(self):
"""Test that RandomStreams.multinomial generates the same results as numpy"""
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(234)
random = RandomStreams(utt.fetch_seed())
fn = function([], random.multinomial((4,4), 1, [0.1]*10), updates=random.updates())
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(234).randint(2**30)
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
numpy_val0 = rng.multinomial(1, [0.1]*10, size=(4,4))
numpy_val1 = rng.multinomial(1, [0.1]*10, size=(4,4))
......@@ -153,11 +153,12 @@ class T_SharedRandomStreams(unittest.TestCase):
# On matrices, for each row, the elements of that row should be shuffled.
# Note that this differs from numpy.random.shuffle, where all the elements
# of the matrix are shuffled.
random = RandomStreams(234)
random = RandomStreams(utt.fetch_seed())
m_input = tensor.dmatrix()
f = function([m_input], random.shuffle_row_elements(m_input), updates=random.updates())
val_rng = numpy.random.RandomState(unittest_tools.fetch_seed())
# Generate the elements to be shuffled
val_rng = numpy.random.RandomState(utt.fetch_seed()+42)
in_mval = val_rng.uniform(-2, 2, size=(20,5))
fn_mval0 = f(in_mval)
fn_mval1 = f(in_mval)
......@@ -168,7 +169,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert not numpy.all(in_mval == fn_mval1)
assert not numpy.all(fn_mval0 == fn_mval1)
rng_seed = numpy.random.RandomState(234).randint(2**30)
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed))
numpy_mval0 = in_mval.copy()
numpy_mval1 = in_mval.copy()
......@@ -182,7 +183,7 @@ class T_SharedRandomStreams(unittest.TestCase):
# On vectors, the behaviour is the same as numpy.random.shuffle,
# except that it does not work in place, but returns a shuffled vector.
random1 = RandomStreams(234)
random1 = RandomStreams(utt.fetch_seed())
v_input = tensor.dvector()
f1 = function([v_input], random1.shuffle_row_elements(v_input))
......@@ -203,7 +204,7 @@ class T_SharedRandomStreams(unittest.TestCase):
def test_default_updates(self):
# Basic case: default_updates
random_a = RandomStreams(234)
random_a = RandomStreams(utt.fetch_seed())
out_a = random_a.uniform((2,2))
fn_a = function([], out_a)
fn_a_val0 = fn_a()
......@@ -214,7 +215,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(abs(nearly_zeros()) < 1e-5)
# Explicit updates #1
random_b = RandomStreams(234)
random_b = RandomStreams(utt.fetch_seed())
out_b = random_b.uniform((2,2))
fn_b = function([], out_b, updates=random_b.updates())
fn_b_val0 = fn_b()
......@@ -223,7 +224,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_b_val1 == fn_a_val1)
# Explicit updates #2
random_c = RandomStreams(234)
random_c = RandomStreams(utt.fetch_seed())
out_c = random_c.uniform((2,2))
fn_c = function([], out_c, updates=[out_c.update])
fn_c_val0 = fn_c()
......@@ -232,7 +233,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_c_val1 == fn_a_val1)
# No updates at all
random_d = RandomStreams(234)
random_d = RandomStreams(utt.fetch_seed())
out_d = random_d.uniform((2,2))
fn_d = function([], out_d, no_default_updates=True)
fn_d_val0 = fn_d()
......@@ -241,7 +242,7 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_d_val1 == fn_d_val0)
# No updates for out
random_e = RandomStreams(234)
random_e = RandomStreams(utt.fetch_seed())
out_e = random_e.uniform((2,2))
fn_e = function([], out_e, no_default_updates=[out_e.rng])
fn_e_val0 = fn_e()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论