提交 d647578c authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Refactor JAX implementations of `RandomVariable`

上级 054ad0fd
from functools import singledispatch
import jax
import jax.numpy as jnp
from numpy.random import Generator, RandomState
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array,
)
import pytensor.tensor.random.basic as aer
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.shape import Shape
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
......@@ -46,25 +48,189 @@ def jax_typify_Generator(rng, **kwargs):
return state
@jax_funcify.register(RandomVariable)
@jax_funcify.register(aer.RandomVariable)
def jax_funcify_RandomVariable(op, node, **kwargs):
name = op.name
"""JAX implementation of random variables."""
rv = node.outputs[1]
out_dtype = rv.type.dtype
out_size = rv.type.shape
# TODO Make sure there's a 1-to-1 correspondance with names
if not hasattr(jax.random, name):
if isinstance(op, aer.MvNormalRV):
# PyTensor sets the `size` to the concatenation of the support shape
# and the batch shape, while JAX explicitly requires the batch
# shape only for the multivariate normal.
out_size = node.outputs[1].type.shape[:-1]
# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if None in out_size:
if not isinstance(node.inputs[1].owner.op, Shape):
raise NotImplementedError(
f"No JAX conversion for the given distribution: {name}"
"""JAX random variables require concrete values for the `size` parameter of the distributions.
Concrete values are either constants, or the shape of an array.
"""
)
dtype = node.outputs[1].dtype
def sample_fn(rng, size, dtype, *parameters):
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
else:
def sample_fn(rng, size, dtype, *parameters):
return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
return sample_fn
@singledispatch
def jax_sample_fn(op):
name = op.name
raise NotImplementedError(
f"No JAX implementation for the given distribution: {name}"
)
@jax_sample_fn.register(aer.BetaRV)
@jax_sample_fn.register(aer.DirichletRV)
@jax_sample_fn.register(aer.PoissonRV)
@jax_sample_fn.register(aer.MvNormalRV)
def jax_sample_fn_generic(op):
"""Generic JAX implementation of random variables."""
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.CauchyRV)
@jax_sample_fn.register(aer.LaplaceRV)
@jax_sample_fn.register(aer.LogisticRV)
@jax_sample_fn.register(aer.NormalRV)
def jax_sample_fn_loc_scale(op):
"""JAX implementation of random variables in the loc-scale families.
JAX only implements the standard version of random variables in the
loc-scale family. We thus need to translate and rescale the results
manually.
"""
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
loc, scale = parameters
sample = loc + jax_op(rng_key, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.BernoulliRV)
@jax_sample_fn.register(aer.CategoricalRV)
def jax_sample_fn_no_dtype(op):
"""Generic JAX implementation of random variables."""
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
sample = jax_op(rng_key, *parameters, shape=size)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.RandIntRV)
@jax_sample_fn.register(aer.UniformRV)
def jax_sample_fn_uniform(op):
"""JAX implementation of random variables with uniform density.
We need to pass the arguments as keyword arguments since the order
of arguments is not the same.
"""
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
minval, maxval = parameters
sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.ParetoRV)
@jax_sample_fn.register(aer.GammaRV)
def jax_sample_fn_shape_rate(op):
"""JAX implementation of random variables in the shape-rate family.
JAX only implements the standard version of random variables in the
shape-rate family. We thus need to rescale the results manually.
"""
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
(shape, rate) = parameters
sample = jax_op(rng_key, shape, size, dtype) / rate
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.ExponentialRV)
def jax_sample_fn_exponential(op):
"""JAX implementation of `ExponentialRV`."""
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
(scale,) = parameters
sample = jax.random.exponential(rng_key, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
def random_variable(rng, size, dtype_num, *args):
if not op.inplace:
rng = rng.copy()
prng = rng["jax_state"]
data = getattr(jax.random, name)(key=prng, shape=size)
smpl_value = jnp.array(data, dtype=dtype)
rng["jax_state"] = jax.random.split(prng, num=1)[0]
@jax_sample_fn.register(aer.ChoiceRV)
def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`."""
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
(a, p, replace) = parameters
smpl_value = jax.random.choice(rng_key, a, size, replace, p)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, smpl_value)
return random_variable
return sample_fn
@jax_sample_fn.register(aer.PermutationRV)
def jax_sample_fn_permutation(op):
"""JAX implementation of `PermutationRV`."""
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
(x,) = parameters
sample = jax.random.permutation(rng_key, x)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论