Commit 51210c39 authored by Ricardo Vieira, committed by Ricardo Vieira

Extend supported RandomVariables in JAX backend via NumPyro

Dependency is optional
Parent dcd24a36
......@@ -117,7 +117,7 @@ jobs:
run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
......
......@@ -3,6 +3,7 @@ import warnings
import jax.numpy as jnp
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.basic import infer_static_shape
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
......@@ -102,8 +103,12 @@ def jax_funcify_RavelMultiIndex(op, **kwargs):
@jax_funcify.register(BroadcastTo)
def jax_funcify_BroadcastTo(op, node, **kwargs):
    """Return a JAX implementation of `BroadcastTo`.

    JAX requires concrete shapes for `jnp.broadcast_to`, but some of the
    target-shape inputs may only be known at runtime. We therefore infer as
    much static shape information as possible from the graph and fall back
    to the runtime values only for dimensions that remain unknown.
    """
    # NOTE: the diff artifact carried two `def` lines (old `(op, **kwargs)`
    # and new `(op, node, **kwargs)`); only the node-aware signature is kept,
    # since the body reads `node.inputs`.
    shape = node.inputs[1:]
    static_shape = infer_static_shape(shape)[1]

    def broadcast_to(x, *shape):
        # Prefer the statically-inferred size; use the runtime value otherwise.
        shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
        return jnp.broadcast_to(x, shape)

    return broadcast_to
......
from functools import singledispatch
import jax
import numpy as np
from numpy.random import Generator, RandomState
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array,
......@@ -12,6 +13,13 @@ from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.shape import Shape, Shape_i
# NumPyro supplies samplers for distributions that JAX lacks natively
# (binomial, multinomial, von Mises, ...). It is an optional dependency:
# record its availability so dispatchers can raise a clear error otherwise.
try:
    import numpyro  # noqa: F401

    numpyro_available = True
except ImportError:
    numpyro_available = False

# Integer codes identifying NumPy bit generators by name.
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
......@@ -83,11 +91,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
out_dtype = rv.type.dtype
out_size = rv.type.shape
if isinstance(op, aer.MvNormalRV):
# PyTensor sets the `size` to the concatenation of the support shape
# and the batch shape, while JAX explicitly requires the batch
# shape only for the multivariate normal.
out_size = node.outputs[1].type.shape[:-1]
if op.ndim_supp > 0:
out_size = node.outputs[1].type.shape[: -op.ndim_supp]
# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
......@@ -292,3 +297,75 @@ def jax_sample_fn_permutation(op):
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.BinomialRV)
def jax_sample_fn_binomial(op):
    """Build a binomial sampler; JAX has none, so NumPyro's is used."""
    if not numpyro_available:
        raise NotImplementedError(
            f"No JAX implementation for the given distribution: {op.name}. "
            "Implementation is available if NumPyro is installed."
        )

    from numpyro.distributions.util import binomial

    def sample_fn(rng, size, dtype, n, p):
        # Split the PRNG key and advance the stored state deterministically.
        next_key, draw_key = jax.random.split(rng["jax_state"])
        rng["jax_state"] = next_key
        sample = binomial(key=draw_key, n=n, p=p, shape=size)
        return (rng, sample)

    return sample_fn
@jax_sample_fn.register(aer.MultinomialRV)
def jax_sample_fn_multinomial(op):
    """Build a multinomial sampler; JAX has none, so NumPyro's is used."""
    if not numpyro_available:
        raise NotImplementedError(
            f"No JAX implementation for the given distribution: {op.name}. "
            "Implementation is available if NumPyro is installed."
        )

    from numpyro.distributions.util import multinomial

    def sample_fn(rng, size, dtype, n, p):
        # Split the PRNG key and advance the stored state deterministically.
        next_key, draw_key = jax.random.split(rng["jax_state"])
        rng["jax_state"] = next_key
        sample = multinomial(key=draw_key, n=n, p=p, shape=size)
        return (rng, sample)

    return sample_fn
@jax_sample_fn.register(aer.VonMisesRV)
def jax_sample_fn_vonmises(op):
    """Build a von Mises sampler; JAX has none, so NumPyro's is used."""
    if not numpyro_available:
        raise NotImplementedError(
            f"No JAX implementation for the given distribution: {op.name}. "
            "Implementation is available if NumPyro is installed."
        )

    from numpyro.distributions.util import von_mises_centered

    def sample_fn(rng, size, dtype, mu, kappa):
        # Split the PRNG key and advance the stored state deterministically.
        next_key, draw_key = jax.random.split(rng["jax_state"])
        rng["jax_state"] = next_key
        # NumPyro draws are centered at zero; shift by `mu` and wrap the
        # result back into [-pi, pi).
        centered = von_mises_centered(
            key=draw_key, concentration=kappa, shape=size, dtype=dtype
        )
        sample = (centered + mu + np.pi) % (2.0 * np.pi) - np.pi
        return (rng, sample)

    return sample_fn
......@@ -2,10 +2,11 @@ from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.tensor import abs as abs_t
from pytensor.tensor import exp, floor, log, log1p, reciprocal, sqrt
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
from pytensor.tensor.basic import MakeVector, cast, ones_like, switch, zeros_like
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
BetaBinomialRV,
ChiSquareRV,
GenGammaRV,
GeometricRV,
......@@ -14,6 +15,8 @@ from pytensor.tensor.random.basic import (
LogNormalRV,
NegBinomialRV,
WaldRV,
beta,
binomial,
gamma,
normal,
poisson,
......@@ -133,6 +136,15 @@ def wald_from_normal_uniform(fgraph, node):
return [next_rng, cast(w, dtype=node.default_output().dtype)]
@node_rewriter([BetaBinomialRV])
def beta_binomial_from_beta_binomial(fgraph, node):
    """Rewrite a BetaBinomial draw as a Beta draw feeding a Binomial draw."""
    rng, *other_inputs, n, a, b = node.inputs
    # Align the parameter shapes so both component RVs broadcast identically.
    n, a, b = broadcast_arrays(n, a, b)
    next_rng, beta_draw = beta.make_node(rng, *other_inputs, a, b).outputs
    final_rng, draws = binomial.make_node(next_rng, *other_inputs, n, beta_draw).outputs
    return [final_rng, draws]
random_vars_opt = SequenceDB()
random_vars_opt.register(
"lognormal_from_normal",
......@@ -174,6 +186,11 @@ random_vars_opt.register(
in2out(wald_from_normal_uniform),
"jax",
)
# Make the BetaBinomial -> Beta + Binomial rewrite available in JAX mode.
random_vars_opt.register(
    "beta_binomial_from_beta_binomial",
    in2out(beta_binomial_from_beta_binomial),
    "jax",
)

optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
optdb.register(
......
......@@ -19,6 +19,9 @@ from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_val
jax = pytest.importorskip("jax")
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
def test_random_RandomStream():
"""Two successive calls of a compiled graph using `RandomStream` should
return different values.
......@@ -377,6 +380,25 @@ def test_random_updates(rng_ctor):
# https://stackoverflow.com/a/48603469
lambda mean, scale: (mean / scale, 0, scale),
),
pytest.param(
aer.vonmises,
[
set_test_value(
at.dvector(),
np.array([-0.5, 1.3], dtype=np.float64),
),
set_test_value(
at.dvector(),
np.array([5.5, 13.0], dtype=np.float64),
),
],
(2,),
"vonmises",
lambda mu, kappa: (kappa, mu),
marks=pytest.mark.skipif(
not numpyro_available, reason="VonMises dispatch requires numpyro"
),
),
],
)
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
......@@ -519,6 +541,83 @@ def test_negative_binomial():
)
@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro")
def test_binomial():
    """Binomial draws via the JAX backend match the analytic moments."""
    rng = shared(np.random.RandomState(123))
    n = np.array([10, 40])
    p = np.array([0.3, 0.7])

    draws_fn = function(
        [], at.random.binomial(n, p, size=(10_000, 2), rng=rng), mode=jax_mode
    )
    draws = draws_fn()

    expected_mean = n * p
    expected_std = np.sqrt(n * p * (1 - p))
    np.testing.assert_allclose(draws.mean(axis=0), expected_mean, rtol=0.1)
    np.testing.assert_allclose(draws.std(axis=0), expected_std, rtol=0.1)
@pytest.mark.skipif(
    not numpyro_available, reason="BetaBinomial dispatch requires numpyro"
)
def test_beta_binomial():
    """Beta-binomial draws via the JAX backend match the analytic moments."""
    rng = shared(np.random.RandomState(123))
    n = np.array([10, 40])
    a = np.array([1.5, 13])
    b = np.array([0.5, 9])

    draws_fn = function(
        [], at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng), mode=jax_mode
    )
    draws = draws_fn()

    expected_mean = n * a / (a + b)
    expected_var = (n * a * b * (a + b + n)) / ((a + b) ** 2 * (a + b + 1))
    np.testing.assert_allclose(draws.mean(axis=0), expected_mean, rtol=0.1)
    np.testing.assert_allclose(draws.std(axis=0), np.sqrt(expected_var), rtol=0.1)
@pytest.mark.skipif(
    not numpyro_available, reason="Multinomial dispatch requires numpyro"
)
def test_multinomial():
    """Multinomial draws via the JAX backend match the analytic moments."""
    rng = shared(np.random.RandomState(123))
    n = np.array([10, 40])
    p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])

    draws_fn = function(
        [], at.random.multinomial(n, p, size=(10_000, 2), rng=rng), mode=jax_mode
    )
    draws = draws_fn()

    expected_mean = n[..., None] * p
    expected_std = np.sqrt(n[..., None] * p * (1 - p))
    np.testing.assert_allclose(draws.mean(axis=0), expected_mean, rtol=0.1)
    np.testing.assert_allclose(draws.std(axis=0), expected_std, rtol=0.1)
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle():
    # Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle
    # We test that the random draws from the JAX dispatch work as expected in these cases
    rng = shared(np.random.RandomState(123))
    mu = np.array([-30, 40])
    kappa = np.array([100, 10])
    g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
    g_fn = function([], g, mode=jax_mode)
    samples = g_fn()
    # The sample mean should match `mu` wrapped back into [-pi, pi).
    np.testing.assert_allclose(
        samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
    )

    # Circvar only does the correct thing in more recent versions of Scipy
    # https://github.com/scipy/scipy/pull/5747
    # np.testing.assert_allclose(
    #     stats.circvar(samples, axis=0),
    #     1 - special.iv(1, kappa) / special.iv(0, kappa),
    #     rtol=0.1,
    # )

    # For now simple compare with std from numpy draws
    rng = np.random.default_rng(123)
    ref_samples = rng.vonmises(mu, kappa, size=(10_000, 2))
    np.testing.assert_allclose(
        np.std(samples, axis=0), np.std(ref_samples, axis=0), rtol=0.1
    )
def test_random_unimplemented():
"""Compiling a graph with a non-supported `RandomVariable` should
raise an error.
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment