提交 539550b7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Test the default_update behavior of shared random streams.

上级 e11e3472
...@@ -201,6 +201,55 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -201,6 +201,55 @@ class T_SharedRandomStreams(unittest.TestCase):
self.assertRaises(TypeError, f1, in_mval) self.assertRaises(TypeError, f1, in_mval)
self.assertRaises(TypeError, f, in_vval) self.assertRaises(TypeError, f, in_vval)
def test_default_updates(self):
# Basic case: default_updates
random_a = RandomStreams(234)
out_a = random_a.uniform((2,2))
fn_a = function([], out_a)
fn_a_val0 = fn_a()
fn_a_val1 = fn_a()
assert not numpy.all(fn_a_val0 == fn_a_val1)
nearly_zeros = function([], out_a + out_a - 2 * out_a)
assert numpy.all(abs(nearly_zeros()) < 1e-5)
# Explicit updates #1
random_b = RandomStreams(234)
out_b = random_b.uniform((2,2))
fn_b = function([], out_b, updates=random_b.updates())
fn_b_val0 = fn_b()
fn_b_val1 = fn_b()
assert numpy.all(fn_b_val0 == fn_a_val0)
assert numpy.all(fn_b_val1 == fn_a_val1)
# Explicit updates #2
random_c = RandomStreams(234)
out_c = random_c.uniform((2,2))
fn_c = function([], out_c, updates=[out_c.update])
fn_c_val0 = fn_c()
fn_c_val1 = fn_c()
assert numpy.all(fn_c_val0 == fn_a_val0)
assert numpy.all(fn_c_val1 == fn_a_val1)
# No updates at all
random_d = RandomStreams(234)
out_d = random_d.uniform((2,2))
fn_d = function([], out_d, no_default_updates=True)
fn_d_val0 = fn_d()
fn_d_val1 = fn_d()
assert numpy.all(fn_d_val0 == fn_a_val0)
assert numpy.all(fn_d_val1 == fn_d_val0)
# No updates for out
random_e = RandomStreams(234)
out_e = random_e.uniform((2,2))
fn_e = function([], out_e, no_default_updates=[out_e.rng])
fn_e_val0 = fn_e()
fn_e_val1 = fn_e()
assert numpy.all(fn_e_val0 == fn_a_val0)
assert numpy.all(fn_e_val1 == fn_e_val0)
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
main("test_randomstreams") main("test_randomstreams")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论