Commit 9ab8df51 authored by Etienne Duchesne, committed by Ricardo Vieira

Simplify dispatch of JAX random variables

Parent: 95ce102d
@@ -105,14 +105,24 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
         assert_size_argument_jax_compatible(node)
 
         def sample_fn(rng, size, *parameters):
-            return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
+            rng_key = rng["jax_state"]
+            rng_key, sampling_key = jax.random.split(rng_key, 2)
+            rng["jax_state"] = rng_key
+            sample = jax_sample_fn(op, node=node)(
+                sampling_key, size, out_dtype, *parameters
+            )
+            return (rng, sample)
 
     else:
 
         def sample_fn(rng, size, *parameters):
-            return jax_sample_fn(op, node=node)(
-                rng, static_size, out_dtype, *parameters
-            )
+            rng_key = rng["jax_state"]
+            rng_key, sampling_key = jax.random.split(rng_key, 2)
+            rng["jax_state"] = rng_key
+            sample = jax_sample_fn(op, node=node)(
+                sampling_key, static_size, out_dtype, *parameters
+            )
+            return (rng, sample)
 
     return sample_fn
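The hunk above centralizes all PRNG bookkeeping in the outer `sample_fn`: it reads the key from the shared state dict, splits it, stores the fresh key back, and hands only the raw `sampling_key` to the dispatched sampler. A minimal sketch of that contract, with illustrative names (`make_sample_fn` and `normal_sampler` are not part of the PyTensor API):

import jax
import jax.numpy as jnp


def normal_sampler(rng_key, size, dtype, loc, scale):
    # Per-distribution sampler: stateless, consumes a raw JAX key.
    return loc + jax.random.normal(rng_key, size, dtype) * scale


def make_sample_fn(sampler):
    def sample_fn(rng, size, *parameters):
        rng_key = rng["jax_state"]
        # Split once here, so individual samplers no longer have to.
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        rng["jax_state"] = rng_key  # advance the shared state
        sample = sampler(sampling_key, size, jnp.float32, *parameters)
        return (rng, sample)

    return sample_fn


rng = {"jax_state": jax.random.PRNGKey(0)}
rng, draw = make_sample_fn(normal_sampler)(rng, (3,), 0.0, 1.0)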
@@ -133,12 +143,9 @@ def jax_sample_fn_generic(op, node):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-        sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, *parameters):
+        sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
+        return sample
 
     return sample_fn
@@ -159,29 +166,23 @@ def jax_sample_fn_loc_scale(op, node):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, *parameters):
         loc, scale = parameters
         if size is None:
             size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
-        sample = loc + jax_op(sampling_key, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = loc + jax_op(rng_key, size, dtype) * scale
+        return sample
 
     return sample_fn
 
 
 @jax_sample_fn.register(ptr.MvNormalRV)
 def jax_sample_mvnormal(op, node):
-    def sample_fn(rng, size, dtype, mean, cov):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, mean, cov):
         sample = jax.random.multivariate_normal(
-            sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
+            rng_key, mean, cov, shape=size, dtype=dtype, method=op.method
        )
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return sample
 
     return sample_fn
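Both samplers above now follow the same stateless shape: given a raw key, draw from the standard distribution and reparameterize. A short sketch of the loc-scale template, assuming a normal draw (the other loc-scale ops dispatch the same way), including the `size is None` fallback to the broadcast shape of the parameters:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
loc = jnp.zeros((2, 1))
scale = jnp.ones((1, 3))

size = None
if size is None:
    # Fall back to the broadcast shape of the parameters.
    size = jnp.broadcast_arrays(loc, scale)[0].shape  # (2, 3)
sample = loc + jax.random.normal(key, size) * scale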
@@ -191,12 +192,9 @@ def jax_sample_fn_bernoulli(op, node):
     """JAX implementation of `BernoulliRV`."""
 
     # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
-    def sample_fn(rng, size, dtype, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-        sample = jax.random.bernoulli(sampling_key, p, shape=size)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, p):
+        sample = jax.random.bernoulli(rng_key, p, shape=size)
+        return sample
 
     return sample_fn
@@ -206,14 +204,10 @@ def jax_sample_fn_categorical(op, node):
     """JAX implementation of `CategoricalRV`."""
 
     # We need a separate dispatch because Categorical expects logits in JAX
-    def sample_fn(rng, size, dtype, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
+    def sample_fn(rng_key, size, dtype, p):
         logits = jax.scipy.special.logit(p)
-        sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax.random.categorical(rng_key, logits=logits, shape=size)
+        return sample
 
     return sample_fn
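The comment in this hunk is about an API mismatch: `jax.random.categorical` consumes logits (it runs Gumbel-max over them, which normalizes like a softmax), not a probability vector, so probabilities must be mapped into log space before sampling. A minimal illustration, using `jnp.log` as the probability-to-logits map:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
p = jnp.array([0.2, 0.3, 0.5])
# categorical normalizes the logits internally, so log(p) draws
# category i with probability p[i].
draws = jax.random.categorical(key, logits=jnp.log(p), shape=(1000,))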
@@ -233,15 +227,10 @@ def jax_sample_fn_uniform(op, node):
         name = "randint"
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, *parameters):
         minval, maxval = parameters
-        sample = jax_op(
-            sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
-        )
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval)
+        return sample
 
     return sample_fn
@@ -258,14 +247,11 @@ def jax_sample_fn_shape_scale(op, node):
     name = op.name
     jax_op = getattr(jax.random, name)
 
-    def sample_fn(rng, size, dtype, shape, scale):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, shape, scale):
         if size is None:
             size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
-        sample = jax_op(sampling_key, shape, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax_op(rng_key, shape, size, dtype) * scale
+        return sample
 
     return sample_fn
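The shape-scale template mirrors loc-scale: JAX's samplers for these families take only the shape parameter and draw with unit scale, so the scale is applied by multiplication afterwards. A sketch assuming the op resolves to `jax.random.gamma`:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape_param, scale = 2.0, 3.0

size = jnp.broadcast_arrays(shape_param, scale)[0].shape  # () for scalars
# jax.random.gamma draws with unit scale; rescale afterwards.
sample = jax.random.gamma(key, shape_param, size) * scale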
@@ -274,14 +260,11 @@ def jax_sample_fn_shape_scale(op, node):
 def jax_sample_fn_exponential(op, node):
     """JAX implementation of `ExponentialRV`."""
 
-    def sample_fn(rng, size, dtype, scale):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, scale):
         if size is None:
             size = jax.numpy.asarray(scale).shape
-        sample = jax.random.exponential(sampling_key, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = jax.random.exponential(rng_key, size, dtype) * scale
+        return sample
 
     return sample_fn
@@ -290,14 +273,11 @@ def jax_sample_fn_exponential(op, node):
 def jax_sample_fn_t(op, node):
     """JAX implementation of `StudentTRV`."""
 
-    def sample_fn(rng, size, dtype, df, loc, scale):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, df, loc, scale):
         if size is None:
             size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
-        sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
+        return sample
 
     return sample_fn
@@ -315,10 +295,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
             "A default JAX rewrite should have materialized the implicit arange"
         )
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
+    def sample_fn(rng_key, size, dtype, *parameters):
         if op.has_p_param:
             a, p, core_shape = parameters
         else:
@@ -327,9 +304,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
         core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim])
 
         if batch_ndim == 0:
-            sample = jax.random.choice(
-                sampling_key, a, shape=core_shape, replace=False, p=p
-            )
+            sample = jax.random.choice(rng_key, a, shape=core_shape, replace=False, p=p)
 
         else:
             if size is None:
@@ -345,7 +320,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
             if p is not None:
                 p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:])
 
-            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
 
             # Ravel the batch dimensions because vmap only works along a single axis
             raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
@@ -366,8 +341,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
             # Reshape the batch dimensions
             sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
 
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return sample
 
     return sample_fn
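Because `jax.random.choice` (and `jax.random.permutation` below) has no notion of batch dimensions, the batched branch splits one key per batch element, ravels the batch axes, and vmaps the draw over the keys. A self-contained sketch of that pattern, with made-up sizes:

import numpy as np
import jax

key = jax.random.PRNGKey(0)
size = (2, 3)  # batch shape
a = jax.numpy.arange(10)

batch_keys = jax.random.split(key, np.prod(size))  # one key per batch element
draws = jax.vmap(
    lambda k: jax.random.choice(k, a, shape=(4,), replace=False)
)(batch_keys)
draws = draws.reshape(size + draws.shape[1:])  # (2, 3, 4)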
@@ -378,9 +352,7 @@ def jax_sample_fn_permutation(op, node):
     batch_ndim = op.batch_ndim(node)
 
-    def sample_fn(rng, size, dtype, *parameters):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
+    def sample_fn(rng_key, size, dtype, *parameters):
         (x,) = parameters
         if batch_ndim:
             # jax.random.permutation has no concept of batch dims
@@ -389,17 +361,16 @@ def jax_sample_fn_permutation(op, node):
             else:
                 x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])
 
-            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
             raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
             raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
                 batch_sampling_keys, raveled_batch_x
             )
             sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
         else:
-            sample = jax.random.permutation(sampling_key, x)
+            sample = jax.random.permutation(rng_key, x)
 
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return sample
 
     return sample_fn
@@ -414,15 +385,9 @@ def jax_sample_fn_binomial(op, node):
     from numpyro.distributions.util import binomial
 
-    def sample_fn(rng, size, dtype, n, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
-        sample = binomial(key=sampling_key, n=n, p=p, shape=size)
-
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, n, p):
+        sample = binomial(key=rng_key, n=n, p=p, shape=size)
+        return sample
 
     return sample_fn
@@ -437,15 +402,9 @@ def jax_sample_fn_multinomial(op, node):
     from numpyro.distributions.util import multinomial
 
-    def sample_fn(rng, size, dtype, n, p):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
-        sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
-
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+    def sample_fn(rng_key, size, dtype, n, p):
+        sample = multinomial(key=rng_key, n=n, p=p, shape=size)
+        return sample
 
     return sample_fn
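The binomial and multinomial hunks keep delegating to numpyro, since JAX's own random module lacked samplers for these distributions when this code was written; the only change is that the helpers now receive the already-split key. A minimal usage sketch of the same numpyro helper (assumes numpyro is installed):

import jax
from numpyro.distributions.util import binomial

key = jax.random.PRNGKey(0)
draws = binomial(key=key, n=10, p=0.3, shape=(5,))  # five Binomial(10, 0.3) draws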
@@ -460,17 +419,12 @@ def jax_sample_fn_vonmises(op, node):
     from numpyro.distributions.util import von_mises_centered
 
-    def sample_fn(rng, size, dtype, mu, kappa):
-        rng_key = rng["jax_state"]
-        rng_key, sampling_key = jax.random.split(rng_key, 2)
-
+    def sample_fn(rng_key, size, dtype, mu, kappa):
         sample = von_mises_centered(
-            key=sampling_key, concentration=kappa, shape=size, dtype=dtype
+            key=rng_key, concentration=kappa, shape=size, dtype=dtype
         )
         sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
 
-        rng["jax_state"] = rng_key
-        return (rng, sample)
+        return sample
 
     return sample_fn
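`von_mises_centered` samples around a zero mean, so the location `mu` is added afterwards and the result wrapped back onto the circle. The wrap is plain modular arithmetic, sketched here:

import numpy as np

def shift_and_wrap(sample, mu):
    # Shift by mu, then map the angle back into [-pi, pi).
    return (sample + mu + np.pi) % (2.0 * np.pi) - np.pi

shift_and_wrap(3.0, 1.0)  # 4.0 wraps to about -2.283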
@@ -796,7 +796,7 @@ def test_random_custom_implementation():
     @jax_sample_fn.register(CustomRV)
     def jax_sample_fn_custom(op, node):
         def sample_fn(rng, size, dtype, *parameters):
-            return (rng, 0)
+            return 0
 
         return sample_fn