提交 09270345 authored 作者: James Bergstra's avatar James Bergstra

Added shared randomstreams tutorial code as a unit test

上级 5c4ffc5b
......@@ -6,17 +6,32 @@ import numpy
from theano.tensor import raw_random
from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.pfunc import pfunc
from theano import function
from theano import tensor
from theano import compile, gof
class T_RandomStreams(unittest.TestCase):
def test_tutorial(self):
srng = RandomStreams(seed=234)
rv_u = srng.uniform((2,2))
rv_n = srng.normal((2,2))
f = function([], rv_u, updates=[rv_u.update])
g = function([], rv_n) #omitting rv_n.update
nearly_zeros = function([], rv_u + rv_u - 2 * rv_u, updates=[rv_u.update])
assert numpy.all(f() != f())
assert numpy.all(g() == g())
assert numpy.all(abs(nearly_zeros()) < 1e-5)
assert isinstance(rv_u.rng.value, numpy.random.RandomState)
def test_basics(self):
random = RandomStreams(234)
fn = pfunc([], random.uniform((2,2)), updates=random.updates())
gn = pfunc([], random.normal((2,2)), updates=random.updates())
fn = function([], random.uniform((2,2)), updates=random.updates())
gn = function([], random.normal((2,2)), updates=random.updates())
fn_val0 = fn()
fn_val1 = fn()
......@@ -40,7 +55,7 @@ class T_RandomStreams(unittest.TestCase):
def test_seed_fn(self):
random = RandomStreams(234)
fn = pfunc([], random.uniform((2,2)), updates=random.updates())
fn = function([], random.uniform((2,2)), updates=random.updates())
random.seed(888)
......@@ -62,7 +77,7 @@ class T_RandomStreams(unittest.TestCase):
random = RandomStreams(234)
out = random.uniform((2,2))
fn = pfunc([], out, updates=random.updates())
fn = function([], out, updates=random.updates())
random.seed(888)
......@@ -80,7 +95,7 @@ class T_RandomStreams(unittest.TestCase):
random = RandomStreams(234)
out = random.uniform((2,2))
fn = pfunc([], out, updates=random.updates())
fn = function([], out, updates=random.updates())
random.seed(888)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论