提交 19b0cf16 authored 作者: James Bergstra's avatar James Bergstra

added tests for the borrow argument on SharedRandomStream objects

上级 3a260ce2
......@@ -6,7 +6,7 @@ import numpy
from theano.tensor import raw_random
from theano.tensor.shared_randomstreams import RandomStreams
from theano import function
from theano import function, shared
from theano import tensor
from theano import compile, config, gof
......@@ -611,6 +611,85 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(abs(val1) <= 1)
def test_shared_constructor_borrow(self):
rng = numpy.random.RandomState(123)
s_rng_default = shared(rng)
s_rng_True = shared(rng, borrow=True)
s_rng_False = shared(rng, borrow=False)
# test borrow contract: that False means a copy must have been made
assert s_rng_default.container.storage[0] is not rng
assert s_rng_False.container.storage[0] is not rng
# test current implementation: that True means a copy was not made
assert s_rng_True.container.storage[0] is rng
# ensure that all the random number generators are in the same state
v = rng.randn()
v0 = s_rng_default.container.storage[0].randn()
v1 = s_rng_False.container.storage[0].randn()
assert v == v0 == v1
def test_get_value_borrow(self):
rng = numpy.random.RandomState(123)
s_rng = shared(rng)
r_ = s_rng.container.storage[0]
r_T = s_rng.get_value(borrow=True)
r_F = s_rng.get_value(borrow=False)
#the contract requires that borrow=False returns a copy
assert r_ is not r_F
# the current implementation allows for True to return the real thing
assert r_ is r_T
#either way, the rngs should all be in the same state
assert r_.rand() == r_F.rand()
def test_get_value_internal_type(self):
rng = numpy.random.RandomState(123)
s_rng = shared(rng)
# there is no special behaviour required of return_internal_type
# this test just ensures that the flag doesn't screw anything up
# by repeating the get_value_borrow test.
r_ = s_rng.container.storage[0]
r_T = s_rng.get_value(borrow=True, return_internal_type=True)
r_F = s_rng.get_value(borrow=False, return_internal_type=True)
#the contract requires that borrow=False returns a copy
assert r_ is not r_F
# the current implementation allows for True to return the real thing
assert r_ is r_T
#either way, the rngs should all be in the same state
assert r_.rand() == r_F.rand()
def test_set_value_borrow(self):
rng = numpy.random.RandomState(123)
s_rng = shared(rng)
new_rng = numpy.random.RandomState(234234)
# Test the borrow contract is respected:
# assigning with borrow=False makes a copy
s_rng.set_value(new_rng, borrow=False)
assert new_rng is not s_rng.container.storage[0]
assert new_rng.randn() == s_rng.container.storage[0].randn()
# Test that the current implementation is actually borrowing when it can.
rr = numpy.random.RandomState(33)
s_rng.set_value(rr, borrow=True)
assert rr is s_rng.container.storage[0]
if __name__ == '__main__':
from theano.tests import main
main("test_shared_randomstreams")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论