提交 1f320096 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add tests for dtype parameter of random functions.

上级 fea13d01
...@@ -648,6 +648,25 @@ class T_RandomStreams(unittest.TestCase): ...@@ -648,6 +648,25 @@ class T_RandomStreams(unittest.TestCase):
assert numpy.all(val2 == numpy_val2) assert numpy.all(val2 == numpy_val2)
self.assertRaises(ValueError, made.g, n_val[:-1], pvals_val[:-1]) self.assertRaises(ValueError, made.g, n_val[:-1], pvals_val[:-1])
def test_dtype(self):
m = Module()
m.random = RandomStreams(utt.fetch_seed())
low = tensor.lscalar()
high = tensor.lscalar()
out = m.random.random_integers(low=low, high=high, size=(20,), dtype='int8')
assert out.dtype == 'int8'
m.f = Method([low, high], out)
made = m.make()
made.random.initialize()
val0 = made.f(0, 9)
assert val0.dtype == 'int8'
val1 = made.f(255, 257)
assert val1.dtype == 'int8'
assert numpy.all(abs(val1) <= 1)
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
main("test_randomstreams") main("test_randomstreams")
...@@ -682,6 +682,23 @@ class T_random_function(unittest.TestCase): ...@@ -682,6 +682,23 @@ class T_random_function(unittest.TestCase):
assert numpy.all(val2 == numpy_val2) assert numpy.all(val2 == numpy_val2)
self.assertRaises(ValueError, g, rng2, n_val[:-1], pvals_val[:-1]) self.assertRaises(ValueError, g, rng2, n_val[:-1], pvals_val[:-1])
def test_dtype(self):
rng_R = random_state_type()
low = tensor.lscalar()
high = tensor.lscalar()
post_r, out = random_integers(rng_R, low=low, high=high, size=(20,), dtype='int8')
assert out.dtype == 'int8'
f = compile.function([rng_R, low, high], [post_r, out])
rng = numpy.random.RandomState(utt.fetch_seed())
rng0, val0 = f(rng, 0, 9)
assert val0.dtype == 'int8'
rng1, val1 = f(rng0, 255, 257)
assert val1.dtype == 'int8'
assert numpy.all(abs(val1) <= 1)
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
main("test_raw_random") main("test_raw_random")
...@@ -586,6 +586,22 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -586,6 +586,22 @@ class T_SharedRandomStreams(unittest.TestCase):
assert numpy.all(val2 == numpy_val2) assert numpy.all(val2 == numpy_val2)
self.assertRaises(ValueError, g, n_val[:-1], pvals_val[:-1]) self.assertRaises(ValueError, g, n_val[:-1], pvals_val[:-1])
def test_dtype(self):
random = RandomStreams(utt.fetch_seed())
low = tensor.lscalar()
high = tensor.lscalar()
out = random.random_integers(low=low, high=high, size=(20,), dtype='int8')
assert out.dtype == 'int8'
f = function([low, high], out)
val0 = f(0, 9)
assert val0.dtype == 'int8'
val1 = f(255, 257)
assert val1.dtype == 'int8'
assert numpy.all(abs(val1) <= 1)
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
main("test_shared_randomstreams") main("test_shared_randomstreams")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论