Unverified 提交 3af923ba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix nan in jax implementation of Multinomial (#1328)

上级 0b56ed9c
...@@ -409,12 +409,14 @@ def jax_sample_fn_multinomial(op, node): ...@@ -409,12 +409,14 @@ def jax_sample_fn_multinomial(op, node):
sampling_rng = jax.random.split(rng_key, binom_p.shape[0]) sampling_rng = jax.random.split(rng_key, binom_p.shape[0])
def _binomial_sample_fn(carry, p_rng): def _binomial_sample_fn(carry, p_rng):
s, rho = carry remaining_n, remaining_p = carry
p, rng = p_rng p, rng = p_rng
samples = jax.random.binomial(rng, s, p / rho) samples = jnp.where(
s = s - samples p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p)
rho = rho - p )
return ((s, rho), samples) remaining_n -= samples
remaining_p -= p
return ((remaining_n, remaining_p), samples)
(remain, _), samples = jax.lax.scan( (remain, _), samples = jax.lax.scan(
_binomial_sample_fn, _binomial_sample_fn,
......
...@@ -733,6 +733,18 @@ def test_multinomial(): ...@@ -733,6 +733,18 @@ def test_multinomial():
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1 samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
) )
# Test with p=0
g = pt.random.multinomial(n=5, p=pt.eye(4))
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_array_equal(samples, np.eye(4) * 5)
# Test with n=0
g = pt.random.multinomial(n=0, p=np.ones(4) / 4)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_array_equal(samples, np.zeros(4))
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") @pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle(): def test_vonmises_mu_outside_circle():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论