Unverified 提交 3af923ba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix nan in jax implementation of Multinomial (#1328)

上级 0b56ed9c
......@@ -409,12 +409,14 @@ def jax_sample_fn_multinomial(op, node):
sampling_rng = jax.random.split(rng_key, binom_p.shape[0])
def _binomial_sample_fn(carry, p_rng):
s, rho = carry
remaining_n, remaining_p = carry
p, rng = p_rng
samples = jax.random.binomial(rng, s, p / rho)
s = s - samples
rho = rho - p
return ((s, rho), samples)
samples = jnp.where(
p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p)
)
remaining_n -= samples
remaining_p -= p
return ((remaining_n, remaining_p), samples)
(remain, _), samples = jax.lax.scan(
_binomial_sample_fn,
......
......@@ -733,6 +733,18 @@ def test_multinomial():
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
)
# Test with p=0
g = pt.random.multinomial(n=5, p=pt.eye(4))
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_array_equal(samples, np.eye(4) * 5)
# Test with n=0
g = pt.random.multinomial(n=0, p=np.ones(4) / 4)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_array_equal(samples, np.zeros(4))
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论