提交 43e62d0f authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

added test for choice function

上级 f8e75bd5
......@@ -182,6 +182,28 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_choice(self):
"""Test that RandomStreams.choice generates the same results as numpy"""
# numpy.random.choice is only available for numpy versions >= 1.7
major, minor, _ = numpy.version.short_version.split('.')
if (int(major), int(minor)) < (1, 7):
raise utt.SkipTest('choice requires at NumPy version >= 1.7 '
'(%s)' % numpy.__version__)
# Check over two calls to see if the random state is correctly updated.
random = RandomStreams(utt.fetch_seed())
fn = function([], random.choice((11, 8), 10, 1, 0))
fn_val0 = fn()
fn_val1 = fn()
rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
numpy_val0 = rng.choice(10, (11, 8), True, None)
numpy_val1 = rng.choice(10, (11, 8), True, None)
assert numpy.all(fn_val0 == numpy_val0)
assert numpy.all(fn_val1 == numpy_val1)
def test_permutation(self):
"""Test that RandomStreams.permutation generates the same results as numpy"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论