提交 e9b56ae3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add more tests for JAX implementation of ChoiceRV

上级 36b2ac9e
......@@ -545,27 +545,49 @@ def test_random_dirichlet(parameter, size):
def test_random_choice():
# Elements are picked at equal frequency
num_samples = 10000
# `replace=True` and `p is None`
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
g = pt.random.choice(np.arange(4), size=10_000, rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert samples.shape == (10_000,)
# Elements are picked at equal frequency
np.testing.assert_allclose(np.mean(samples == 3), 0.25, 2)
# `replace=True` and `p is not None`
rng = shared(np.random.default_rng(123))
g = pt.random.choice(4, p=np.array([0.0, 0.5, 0.0, 0.5]), size=(5, 2), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
assert samples.shape == (5, 2)
# Only odd numbers are picked
assert np.all(samples % 2 == 1)
# `replace=False` produces unique results
# `replace=False` and `p is None`
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert len(np.unique(samples)) == 99
assert samples.shape == (2, 49)
# Elements are unique
assert len(np.unique(samples)) == 98
# We can pass an array with probabilities
# `replace=False` and `p is not None`
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
g = pt.random.choice(
8,
p=np.array([0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0]),
size=3,
rng=rng,
replace=False,
)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples, np.zeros(10))
assert samples.shape == (3,)
# Elements are unique
assert len(np.unique(samples)) == 3
# Only even numbers are picked
assert np.all(samples % 2 == 0)
def test_random_categorical():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论