提交 90da9e61 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in JAX implementation of RandomVariables with implicit size

上级 17a5e424
...@@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs): ...@@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
assert_size_argument_jax_compatible(node) assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
return jax_sample_fn(op)(rng, size, out_dtype, *parameters) return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
else: else:
...@@ -161,6 +164,8 @@ def jax_sample_fn_loc_scale(op): ...@@ -161,6 +164,8 @@ def jax_sample_fn_loc_scale(op):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2) rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters loc, scale = parameters
if size is None:
size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
sample = loc + jax_op(sampling_key, size, dtype) * scale sample = loc + jax_op(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
...@@ -184,15 +189,16 @@ def jax_sample_fn_bernoulli(op): ...@@ -184,15 +189,16 @@ def jax_sample_fn_bernoulli(op):
@jax_sample_fn.register(ptr.CategoricalRV) @jax_sample_fn.register(ptr.CategoricalRV)
def jax_sample_fn_no_dtype(op): def jax_sample_fn_categorical(op):
"""Generic JAX implementation of random variables.""" """JAX implementation of `CategoricalRV`."""
name = op.name
jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters): # We need a separate dispatch because Categorical expects logits in JAX
def sample_fn(rng, size, dtype, p):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2) rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size)
logits = jax.scipy.special.logit(p)
sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
rng["jax_state"] = rng_key rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
...@@ -243,6 +249,8 @@ def jax_sample_fn_shape_scale(op): ...@@ -243,6 +249,8 @@ def jax_sample_fn_shape_scale(op):
def sample_fn(rng, size, dtype, shape, scale): def sample_fn(rng, size, dtype, shape, scale):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2) rng_key, sampling_key = jax.random.split(rng_key, 2)
if size is None:
size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
sample = jax_op(sampling_key, shape, size, dtype) * scale sample = jax_op(sampling_key, shape, size, dtype) * scale
rng["jax_state"] = rng_key rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
...@@ -254,10 +262,11 @@ def jax_sample_fn_shape_scale(op): ...@@ -254,10 +262,11 @@ def jax_sample_fn_shape_scale(op):
def jax_sample_fn_exponential(op): def jax_sample_fn_exponential(op):
"""JAX implementation of `ExponentialRV`.""" """JAX implementation of `ExponentialRV`."""
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, scale):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2) rng_key, sampling_key = jax.random.split(rng_key, 2)
(scale,) = parameters if size is None:
size = jax.numpy.asarray(scale).shape
sample = jax.random.exponential(sampling_key, size, dtype) * scale sample = jax.random.exponential(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
...@@ -269,14 +278,11 @@ def jax_sample_fn_exponential(op): ...@@ -269,14 +278,11 @@ def jax_sample_fn_exponential(op):
def jax_sample_fn_t(op): def jax_sample_fn_t(op):
"""JAX implementation of `StudentTRV`.""" """JAX implementation of `StudentTRV`."""
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, df, loc, scale):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2) rng_key, sampling_key = jax.random.split(rng_key, 2)
( if size is None:
df, size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
loc,
scale,
) = parameters
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
rng["jax_state"] = rng_key rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
......
...@@ -509,6 +509,34 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -509,6 +509,34 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
assert test_res.pvalue > 0.01 assert test_res.pvalue > 0.01
@pytest.mark.parametrize(
"rv_fn",
[
lambda param_that_implies_size: ptr.normal(
loc=0, scale=pt.exp(param_that_implies_size)
),
lambda param_that_implies_size: ptr.exponential(
scale=pt.exp(param_that_implies_size)
),
lambda param_that_implies_size: ptr.gamma(
shape=1, scale=pt.exp(param_that_implies_size)
),
lambda param_that_implies_size: ptr.t(
df=3, loc=param_that_implies_size, scale=1
),
],
)
def test_size_implied_by_broadcasted_parameters(rv_fn):
# We need a parameter with untyped shapes to test broadcasting does not result in identical draws
param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None))
rv = rv_fn(param_that_implies_size)
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode)
assert draws.shape == (2, 2)
assert np.unique(draws).size == 4
@pytest.mark.parametrize("size", [(), (4,)]) @pytest.mark.parametrize("size", [(), (4,)])
def test_random_bernoulli(size): def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论