Commit 90da9e61 authored by Ricardo Vieira, committed by Ricardo Vieira

Fix bug in JAX implementation of RandomVariables with implicit size

Parent 17a5e424
......@@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, dtype, *parameters):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
else:
......@@ -161,6 +164,8 @@ def jax_sample_fn_loc_scale(op):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters
if size is None:
size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
sample = loc + jax_op(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
......@@ -184,15 +189,16 @@ def jax_sample_fn_bernoulli(op):
@jax_sample_fn.register(ptr.CategoricalRV)
def jax_sample_fn_no_dtype(op):
"""Generic JAX implementation of random variables."""
name = op.name
jax_op = getattr(jax.random, name)
def jax_sample_fn_categorical(op):
"""JAX implementation of `CategoricalRV`."""
def sample_fn(rng, size, dtype, *parameters):
# We need a separate dispatch because Categorical expects logits in JAX
def sample_fn(rng, size, dtype, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size)
logits = jax.scipy.special.logit(p)
sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)
......@@ -243,6 +249,8 @@ def jax_sample_fn_shape_scale(op):
def sample_fn(rng, size, dtype, shape, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
if size is None:
size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
sample = jax_op(sampling_key, shape, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
......@@ -254,10 +262,11 @@ def jax_sample_fn_shape_scale(op):
def jax_sample_fn_exponential(op):
"""JAX implementation of `ExponentialRV`."""
def sample_fn(rng, size, dtype, *parameters):
def sample_fn(rng, size, dtype, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(scale,) = parameters
if size is None:
size = jax.numpy.asarray(scale).shape
sample = jax.random.exponential(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
......@@ -269,14 +278,11 @@ def jax_sample_fn_exponential(op):
def jax_sample_fn_t(op):
"""JAX implementation of `StudentTRV`."""
def sample_fn(rng, size, dtype, *parameters):
def sample_fn(rng, size, dtype, df, loc, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(
df,
loc,
scale,
) = parameters
if size is None:
size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
......
......@@ -509,6 +509,34 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
assert test_res.pvalue > 0.01
@pytest.mark.parametrize(
    "rv_fn",
    [
        lambda param: ptr.normal(loc=0, scale=pt.exp(param)),
        lambda param: ptr.exponential(scale=pt.exp(param)),
        lambda param: ptr.gamma(shape=1, scale=pt.exp(param)),
        lambda param: ptr.t(df=3, loc=param, scale=1),
    ],
)
def test_size_implied_by_broadcasted_parameters(rv_fn):
    """RVs with no explicit size must infer the draw shape from their parameters.

    The parameter deliberately has fully unknown static shape so the JAX
    backend cannot resolve the size at compile time and must derive it
    from the runtime value via broadcasting.
    """
    param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None))
    rv = rv_fn(param_that_implies_size)

    draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode)

    # The draw shape follows the (2, 2) parameter, and every entry is an
    # independent sample — broadcasting must not replicate a single draw.
    assert draws.shape == (2, 2)
    assert np.unique(draws).size == 4
@pytest.mark.parametrize("size", [(), (4,)])
def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123))
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment