提交 8c256357 authored 作者: James Bergstra's avatar James Bergstra

multinomial - tests for examples in doc, tests new broadcasting behaviour

上级 25b4ce8a
......@@ -732,6 +732,52 @@ class T_random_function(unittest.TestCase):
assert numpy.all(val2 == numpy_val2)
self.assertRaises(ValueError, g, rng2, n_val[:-1], pvals_val[:-1])
def test_multinomial_tensor3_a(self):
# Test the examples given in the multinomial documentation regarding
# tensor3 objects
rng_R = random_state_type()
n = 9
pvals = tensor.dtensor3()
post_r, out = multinomial(rng_R, n=n, pvals=pvals, size=(1,-1))
assert out.ndim == 3
assert out.broadcastable==(True, False, False)
f = compile.function([rng_R, pvals], [post_r, out], accept_inplace=True)
rng = numpy.random.RandomState(utt.fetch_seed())
numpy_rng = numpy.random.RandomState(utt.fetch_seed())
pvals_val = numpy.asarray([[[.1, .9], [.2, .8], [.3, .7]]])
assert pvals_val.shape == (1, 3, 2)
new_rng, draw = f(rng, pvals_val)
assert draw.shape==(1,3,2)
assert numpy.allclose(draw.sum(axis=2), 9)
def test_multinomial_tensor3_b(self):
# Test the examples given in the multinomial documentation regarding
# tensor3 objects
rng_R = random_state_type()
n = 9
pvals = tensor.dtensor3()
post_r, out = multinomial(rng_R, n=n, pvals=pvals, size=(10, 1,-1))
assert out.ndim == 4
assert out.broadcastable==(False, True, False, False)
f = compile.function([rng_R, pvals], [post_r, out], accept_inplace=True)
rng = numpy.random.RandomState(utt.fetch_seed())
numpy_rng = numpy.random.RandomState(utt.fetch_seed())
pvals_val = numpy.asarray([[[.1, .9], [.2, .8], [.3, .7]]])
assert pvals_val.shape == (1, 3, 2)
out_rng, draw = f(rng, pvals_val)
assert draw.shape==(10,1,3,2)
assert numpy.allclose(draw.sum(axis=3), 9)
def test_dtype(self):
rng_R = random_state_type()
low = tensor.lscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论