Unverified 提交 8a7356ce authored 作者: Etienne Duchesne's avatar Etienne Duchesne 提交者: GitHub

Implement faster Multinomial JAX dispatch (#1316)

上级 2e9d502f
from functools import singledispatch from functools import singledispatch
import jax import jax
import jax.numpy as jnp
import numpy as np import numpy as np
from numpy.random import Generator from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined] from numpy.random.bit_generator import ( # type: ignore[attr-defined]
...@@ -394,16 +395,35 @@ def jax_sample_fn_binomial(op, node): ...@@ -394,16 +395,35 @@ def jax_sample_fn_binomial(op, node):
@jax_sample_fn.register(ptr.MultinomialRV) @jax_sample_fn.register(ptr.MultinomialRV)
def jax_sample_fn_multinomial(op, node): def jax_sample_fn_multinomial(op, node):
if not numpyro_available:
raise NotImplementedError(
f"No JAX implementation for the given distribution: {op.name}. "
"Implementation is available if NumPyro is installed."
)
from numpyro.distributions.util import multinomial
def sample_fn(rng_key, size, dtype, n, p): def sample_fn(rng_key, size, dtype, n, p):
sample = multinomial(key=rng_key, n=n, p=p, shape=size) if size is not None:
n = jnp.broadcast_to(n, size)
p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:])
else:
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
n = jnp.broadcast_to(n, broadcast_shape)
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...]
sampling_rng = jax.random.split(rng_key, binom_p.shape[0])
def _binomial_sample_fn(carry, p_rng):
s, rho = carry
p, rng = p_rng
samples = jax.random.binomial(rng, s, p / rho)
s = s - samples
rho = rho - p
return ((s, rho), samples)
(remain, _), samples = jax.lax.scan(
_binomial_sample_fn,
(n.astype(np.float64), jnp.ones(binom_p.shape[1:])),
(binom_p, sampling_rng),
)
sample = jnp.concatenate(
[jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1
)
return sample return sample
return sample_fn return sample_fn
......
...@@ -703,14 +703,15 @@ def test_beta_binomial(): ...@@ -703,14 +703,15 @@ def test_beta_binomial():
) )
@pytest.mark.skipif(
not numpyro_available, reason="Multinomial dispatch requires numpyro"
)
def test_multinomial(): def test_multinomial():
rng = shared(np.random.default_rng(123)) rng = shared(np.random.default_rng(123))
# test with 'size' argument and n.shape == p.shape[:-1]
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) size = (10_000, 2)
g = pt.random.multinomial(n, p, size=size, rng=rng)
g_fn = compile_random_function([], g, mode="JAX") g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
...@@ -718,6 +719,20 @@ def test_multinomial(): ...@@ -718,6 +719,20 @@ def test_multinomial():
samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1 samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1
) )
# test with no 'size' argument and no static shape
n = np.broadcast_to(np.array([10, 40]), size)
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
pt_n = pt.matrix("n")
pt_p = pt.matrix("p")
g = pt.random.multinomial(pt_n, pt_p, rng=rng, size=None)
g_fn = compile_random_function([pt_n, pt_p], g, mode="JAX")
samples = g_fn(n, p)
np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1)
np.testing.assert_allclose(
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
)
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") @pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle(): def test_vonmises_mu_outside_circle():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论