提交 0f5da80c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

More stable fix for JAX Multinomial

上级 d9b10859
......@@ -412,7 +412,9 @@ def jax_sample_fn_multinomial(op, node):
remaining_n, remaining_p = carry
p, rng = p_rng
samples = jnp.where(
p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p)
remaining_n == 0,
0,
jax.random.binomial(rng, remaining_n, p / remaining_p),
)
remaining_n -= samples
remaining_p -= p
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论