提交 17a5e424 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix JAX implementation of Categorical

上级 e9b56ae3
......@@ -169,6 +169,20 @@ def jax_sample_fn_loc_scale(op):
@jax_sample_fn.register(ptr.BernoulliRV)
def jax_sample_fn_bernoulli(op):
"""JAX implementation of `BernoulliRV`."""
# We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
def sample_fn(rng, size, dtype, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax.random.bernoulli(sampling_key, p, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)
return sample_fn
@jax_sample_fn.register(ptr.CategoricalRV)
def jax_sample_fn_no_dtype(op):
"""Generic JAX implementation of random variables."""
......
......@@ -595,8 +595,16 @@ def test_random_categorical():
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert samples.shape == (10000, 4)
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
# Test zero probabilities
g = pt.random.categorical([0, 0.5, 0, 0.5], size=(1000,), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert samples.shape == (1000,)
assert np.all(samples % 2 == 1)
def test_random_permutation():
array = np.arange(4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论