提交 9ab8df51 authored 作者: Etienne Duchesne's avatar Etienne Duchesne 提交者: Ricardo Vieira

Simplify dispatch of JAX random variables

上级 95ce102d
......@@ -105,14 +105,24 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, *parameters):
return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
rng["jax_state"] = rng_key
sample = jax_sample_fn(op, node=node)(
sampling_key, size, out_dtype, *parameters
)
return (rng, sample)
else:
def sample_fn(rng, size, *parameters):
return jax_sample_fn(op, node=node)(
rng, static_size, out_dtype, *parameters
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
rng["jax_state"] = rng_key
sample = jax_sample_fn(op, node=node)(
sampling_key, static_size, out_dtype, *parameters
)
return (rng, sample)
return sample_fn
......@@ -133,12 +143,9 @@ def jax_sample_fn_generic(op, node):
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
rng["jax_state"] = rng_key
return (rng, sample)
def sample_fn(rng_key, size, dtype, *parameters):
sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
return sample
return sample_fn
......@@ -159,29 +166,23 @@ def jax_sample_fn_loc_scale(op, node):
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, *parameters):
loc, scale = parameters
if size is None:
size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
sample = loc + jax_op(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
sample = loc + jax_op(rng_key, size, dtype) * scale
return sample
return sample_fn
@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_mvnormal(op, node):
def sample_fn(rng, size, dtype, mean, cov):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, mean, cov):
sample = jax.random.multivariate_normal(
sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
rng_key, mean, cov, shape=size, dtype=dtype, method=op.method
)
rng["jax_state"] = rng_key
return (rng, sample)
return sample
return sample_fn
......@@ -191,12 +192,9 @@ def jax_sample_fn_bernoulli(op, node):
"""JAX implementation of `BernoulliRV`."""
# We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
def sample_fn(rng, size, dtype, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax.random.bernoulli(sampling_key, p, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)
def sample_fn(rng_key, size, dtype, p):
sample = jax.random.bernoulli(rng_key, p, shape=size)
return sample
return sample_fn
......@@ -206,14 +204,10 @@ def jax_sample_fn_categorical(op, node):
"""JAX implementation of `CategoricalRV`."""
# We need a separate dispatch because Categorical expects logits in JAX
def sample_fn(rng, size, dtype, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, p):
logits = jax.scipy.special.logit(p)
sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)
sample = jax.random.categorical(rng_key, logits=logits, shape=size)
return sample
return sample_fn
......@@ -233,15 +227,10 @@ def jax_sample_fn_uniform(op, node):
name = "randint"
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, *parameters):
minval, maxval = parameters
sample = jax_op(
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
)
rng["jax_state"] = rng_key
return (rng, sample)
sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval)
return sample
return sample_fn
......@@ -258,14 +247,11 @@ def jax_sample_fn_shape_scale(op, node):
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, shape, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, shape, scale):
if size is None:
size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
sample = jax_op(sampling_key, shape, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
sample = jax_op(rng_key, shape, size, dtype) * scale
return sample
return sample_fn
......@@ -274,14 +260,11 @@ def jax_sample_fn_shape_scale(op, node):
def jax_sample_fn_exponential(op, node):
"""JAX implementation of `ExponentialRV`."""
def sample_fn(rng, size, dtype, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, scale):
if size is None:
size = jax.numpy.asarray(scale).shape
sample = jax.random.exponential(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
sample = jax.random.exponential(rng_key, size, dtype) * scale
return sample
return sample_fn
......@@ -290,14 +273,11 @@ def jax_sample_fn_exponential(op, node):
def jax_sample_fn_t(op, node):
"""JAX implementation of `StudentTRV`."""
def sample_fn(rng, size, dtype, df, loc, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, df, loc, scale):
if size is None:
size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
return sample
return sample_fn
......@@ -315,10 +295,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"A default JAX rewrite should have materialized the implicit arange"
)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, *parameters):
if op.has_p_param:
a, p, core_shape = parameters
else:
......@@ -327,9 +304,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim])
if batch_ndim == 0:
sample = jax.random.choice(
sampling_key, a, shape=core_shape, replace=False, p=p
)
sample = jax.random.choice(rng_key, a, shape=core_shape, replace=False, p=p)
else:
if size is None:
......@@ -345,7 +320,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
if p is not None:
p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:])
batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
# Ravel the batch dimensions because vmap only works along a single axis
raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
......@@ -366,8 +341,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
# Reshape the batch dimensions
sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
rng["jax_state"] = rng_key
return (rng, sample)
return sample
return sample_fn
......@@ -378,9 +352,7 @@ def jax_sample_fn_permutation(op, node):
batch_ndim = op.batch_ndim(node)
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, *parameters):
(x,) = parameters
if batch_ndim:
# jax.random.permutation has no concept of batch dims
......@@ -389,17 +361,16 @@ def jax_sample_fn_permutation(op, node):
else:
x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])
batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
batch_sampling_keys, raveled_batch_x
)
sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
else:
sample = jax.random.permutation(sampling_key, x)
sample = jax.random.permutation(rng_key, x)
rng["jax_state"] = rng_key
return (rng, sample)
return sample
return sample_fn
......@@ -414,15 +385,9 @@ def jax_sample_fn_binomial(op, node):
from numpyro.distributions.util import binomial
def sample_fn(rng, size, dtype, n, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = binomial(key=sampling_key, n=n, p=p, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)
def sample_fn(rng_key, size, dtype, n, p):
sample = binomial(key=rng_key, n=n, p=p, shape=size)
return sample
return sample_fn
......@@ -437,15 +402,9 @@ def jax_sample_fn_multinomial(op, node):
from numpyro.distributions.util import multinomial
def sample_fn(rng, size, dtype, n, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)
def sample_fn(rng_key, size, dtype, n, p):
sample = multinomial(key=rng_key, n=n, p=p, shape=size)
return sample
return sample_fn
......@@ -460,17 +419,12 @@ def jax_sample_fn_vonmises(op, node):
from numpyro.distributions.util import von_mises_centered
def sample_fn(rng, size, dtype, mu, kappa):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
def sample_fn(rng_key, size, dtype, mu, kappa):
sample = von_mises_centered(
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
key=rng_key, concentration=kappa, shape=size, dtype=dtype
)
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
rng["jax_state"] = rng_key
return (rng, sample)
return sample
return sample_fn
......@@ -796,7 +796,7 @@ def test_random_custom_implementation():
@jax_sample_fn.register(CustomRV)
def jax_sample_fn_custom(op, node):
def sample_fn(rng, size, dtype, *parameters):
return (rng, 0)
return 0
return sample_fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论