提交 4c9bcb9e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add test cases for bug fixed.

上级 8b2ad7df
...@@ -545,6 +545,42 @@ class T_random_function(utt.InferShapeTester): ...@@ -545,6 +545,42 @@ class T_random_function(utt.InferShapeTester):
self.assertRaises(ValueError, f, rng_state0, [4]) self.assertRaises(ValueError, f, rng_state0, [4])
self.assertRaises(ValueError, f, rng_state0, [4, 3, 4, 5]) self.assertRaises(ValueError, f, rng_state0, [4, 3, 4, 5])
def test_mixed_shape(self):
# Test when the provided shape is a tuple of ints and scalar vars
rng_R = random_state_type()
shape0 = tensor.lscalar()
shape = (shape0, 3)
post_r, u = uniform(rng_R, size=shape, ndim=2)
f = compile.function([rng_R, shape0], u)
rng_state0 = numpy.random.RandomState(utt.fetch_seed())
assert f(rng_state0, 2).shape == (2, 3)
assert f(rng_state0, 8).shape == (8, 3)
post_r, v = uniform(rng_R, size=shape)
g = compile.function([rng_R, shape0], v)
assert g(rng_state0, 2).shape == (2, 3)
assert g(rng_state0, 8).shape == (8, 3)
def test_mixed_shape_bcastable(self):
# Test when the provided shape is a tuple of ints and scalar vars
rng_R = random_state_type()
shape0 = tensor.lscalar()
shape = (shape0, 1)
post_r, u = uniform(rng_R, size=shape, ndim=2)
assert u.broadcastable == (False, True)
f = compile.function([rng_R, shape0], u)
rng_state0 = numpy.random.RandomState(utt.fetch_seed())
assert f(rng_state0, 2).shape == (2, 1)
assert f(rng_state0, 8).shape == (8, 1)
post_r, v = uniform(rng_R, size=shape)
assert v.broadcastable == (False, True)
g = compile.function([rng_R, shape0], v)
assert g(rng_state0, 2).shape == (2, 1)
assert g(rng_state0, 8).shape == (8, 1)
def test_default_shape(self): def test_default_shape(self):
rng_R = random_state_type() rng_R = random_state_type()
post_r, out = uniform(rng_R) post_r, out = uniform(rng_R)
......
...@@ -333,6 +333,36 @@ class T_SharedRandomStreams(unittest.TestCase): ...@@ -333,6 +333,36 @@ class T_SharedRandomStreams(unittest.TestCase):
self.assertRaises(ValueError, f, [4]) self.assertRaises(ValueError, f, [4])
self.assertRaises(ValueError, f, [4,3,4,5]) self.assertRaises(ValueError, f, [4,3,4,5])
def test_mixed_shape(self):
# Test when the provided shape is a tuple of ints and scalar vars
random = RandomStreams(utt.fetch_seed())
shape0 = tensor.lscalar()
shape = (shape0, 3)
f = function([shape0], random.uniform(size=shape, ndim=2))
assert f(2).shape == (2, 3)
assert f(8).shape == (8, 3)
g = function([shape0], random.uniform(size=shape))
assert g(2).shape == (2, 3)
assert g(8).shape == (8, 3)
def test_mixed_shape_bcastable(self):
# Test when the provided shape is a tuple of ints and scalar vars
random = RandomStreams(utt.fetch_seed())
shape0 = tensor.lscalar()
shape = (shape0, 1)
u = random.uniform(size=shape, ndim=2)
assert u.broadcastable == (False, True)
f = function([shape0], u)
assert f(2).shape == (2, 1)
assert f(8).shape == (8, 1)
v = random.uniform(size=shape)
assert v.broadcastable == (False, True)
g = function([shape0], v)
assert g(2).shape == (2, 1)
assert g(8).shape == (8, 1)
def test_default_shape(self): def test_default_shape(self):
random = RandomStreams(utt.fetch_seed()) random = RandomStreams(utt.fetch_seed())
f = function([], random.uniform()) f = function([], random.uniform())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论