Unverified 提交 8a7356ce authored 作者: Etienne Duchesne's avatar Etienne Duchesne 提交者: GitHub

Implement faster Multinomial JAX dispatch (#1316)

上级 2e9d502f
from functools import singledispatch
import jax
import jax.numpy as jnp
import numpy as np
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
......@@ -394,16 +395,35 @@ def jax_sample_fn_binomial(op, node):
@jax_sample_fn.register(ptr.MultinomialRV)
def jax_sample_fn_multinomial(op, node):
if not numpyro_available:
raise NotImplementedError(
f"No JAX implementation for the given distribution: {op.name}. "
"Implementation is available if NumPyro is installed."
)
from numpyro.distributions.util import multinomial
def sample_fn(rng_key, size, dtype, n, p):
sample = multinomial(key=rng_key, n=n, p=p, shape=size)
if size is not None:
n = jnp.broadcast_to(n, size)
p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:])
else:
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
n = jnp.broadcast_to(n, broadcast_shape)
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...]
sampling_rng = jax.random.split(rng_key, binom_p.shape[0])
def _binomial_sample_fn(carry, p_rng):
s, rho = carry
p, rng = p_rng
samples = jax.random.binomial(rng, s, p / rho)
s = s - samples
rho = rho - p
return ((s, rho), samples)
(remain, _), samples = jax.lax.scan(
_binomial_sample_fn,
(n.astype(np.float64), jnp.ones(binom_p.shape[1:])),
(binom_p, sampling_rng),
)
sample = jnp.concatenate(
[jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1
)
return sample
return sample_fn
......
......@@ -703,14 +703,15 @@ def test_beta_binomial():
)
@pytest.mark.skipif(
not numpyro_available, reason="Multinomial dispatch requires numpyro"
)
def test_multinomial():
rng = shared(np.random.default_rng(123))
# test with 'size' argument and n.shape == p.shape[:-1]
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
size = (10_000, 2)
g = pt.random.multinomial(n, p, size=size, rng=rng)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
......@@ -718,6 +719,20 @@ def test_multinomial():
samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1
)
# test with no 'size' argument and no static shape
n = np.broadcast_to(np.array([10, 40]), size)
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
pt_n = pt.matrix("n")
pt_p = pt.matrix("p")
g = pt.random.multinomial(pt_n, pt_p, rng=rng, size=None)
g_fn = compile_random_function([pt_n, pt_p], g, mode="JAX")
samples = g_fn(n, p)
np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1)
np.testing.assert_allclose(
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
)
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论