提交 d80c0bf7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix Choice and Permutation not respecting the RandomVariable contract

These two RVs don't fall into the traditional RandomVariable contract because they don't have a concept of `batch_ndim`s. The hard-coded ndim params and ndim support were wrong and need to be defined for every node. * ChoiceRV was removed in favor of ChoiceWithoutReplacementRV, which handles the cases without replacement. Those with replacement can trivially be implemented with other basic RVs. * Both Permutation and ChoiceWithoutReplacement now support batch ndims * Avoid materializing the implicit arange
上级 e2e8757f
......@@ -93,8 +93,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
out_dtype = rv.type.dtype
out_size = rv.type.shape
if op.ndim_supp > 0:
out_size = node.outputs[1].type.shape[: -op.ndim_supp]
batch_ndim = op.batch_ndim(node)
out_size = node.default_output().type.shape[:batch_ndim]
# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
......@@ -106,18 +106,18 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
else:
def sample_fn(rng, size, dtype, *parameters):
return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
return sample_fn
@singledispatch
def jax_sample_fn(op):
def jax_sample_fn(op, node):
name = op.name
raise NotImplementedError(
f"No JAX implementation for the given distribution: {name}"
......@@ -128,7 +128,7 @@ def jax_sample_fn(op):
@jax_sample_fn.register(ptr.DirichletRV)
@jax_sample_fn.register(ptr.PoissonRV)
@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_fn_generic(op):
def jax_sample_fn_generic(op, node):
"""Generic JAX implementation of random variables."""
name = op.name
jax_op = getattr(jax.random, name)
......@@ -149,7 +149,7 @@ def jax_sample_fn_generic(op):
@jax_sample_fn.register(ptr.LogisticRV)
@jax_sample_fn.register(ptr.NormalRV)
@jax_sample_fn.register(ptr.StandardNormalRV)
def jax_sample_fn_loc_scale(op):
def jax_sample_fn_loc_scale(op, node):
"""JAX implementation of random variables in the loc-scale families.
JAX only implements the standard version of random variables in the
......@@ -174,7 +174,7 @@ def jax_sample_fn_loc_scale(op):
@jax_sample_fn.register(ptr.BernoulliRV)
def jax_sample_fn_bernoulli(op):
def jax_sample_fn_bernoulli(op, node):
"""JAX implementation of `BernoulliRV`."""
# We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
......@@ -189,7 +189,7 @@ def jax_sample_fn_bernoulli(op):
@jax_sample_fn.register(ptr.CategoricalRV)
def jax_sample_fn_categorical(op):
def jax_sample_fn_categorical(op, node):
"""JAX implementation of `CategoricalRV`."""
# We need a separate dispatch because Categorical expects logits in JAX
......@@ -208,7 +208,7 @@ def jax_sample_fn_categorical(op):
@jax_sample_fn.register(ptr.RandIntRV)
@jax_sample_fn.register(ptr.IntegersRV)
@jax_sample_fn.register(ptr.UniformRV)
def jax_sample_fn_uniform(op):
def jax_sample_fn_uniform(op, node):
"""JAX implementation of random variables with uniform density.
We need to pass the arguments as keyword arguments since the order
......@@ -236,7 +236,7 @@ def jax_sample_fn_uniform(op):
@jax_sample_fn.register(ptr.ParetoRV)
@jax_sample_fn.register(ptr.GammaRV)
def jax_sample_fn_shape_scale(op):
def jax_sample_fn_shape_scale(op, node):
"""JAX implementation of random variables in the shape-scale family.
JAX only implements the standard version of random variables in the
......@@ -259,7 +259,7 @@ def jax_sample_fn_shape_scale(op):
@jax_sample_fn.register(ptr.ExponentialRV)
def jax_sample_fn_exponential(op):
def jax_sample_fn_exponential(op, node):
"""JAX implementation of `ExponentialRV`."""
def sample_fn(rng, size, dtype, scale):
......@@ -275,7 +275,7 @@ def jax_sample_fn_exponential(op):
@jax_sample_fn.register(ptr.StudentTRV)
def jax_sample_fn_t(op):
def jax_sample_fn_t(op, node):
"""JAX implementation of `StudentTRV`."""
def sample_fn(rng, size, dtype, df, loc, scale):
......@@ -290,30 +290,111 @@ def jax_sample_fn_t(op):
return sample_fn
@jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
def jax_funcify_choice(op, node):
    """JAX implementation of `ChoiceWithoutReplacement`.

    JAX's `jax.random.choice` has no concept of batch dimensions, so when the
    node has batch dims we broadcast the parameters, ravel the batch
    dimensions, and `vmap` the core sampling over the raveled axis.
    """
    batch_ndim = op.batch_ndim(node)
    # Inputs after (rng, size, dtype) are: a, optionally p, and core_shape
    a, *p, core_shape = node.inputs[3:]
    a_core_ndim, *p_core_ndim, _ = op.ndims_params

    if batch_ndim and a_core_ndim == 0:
        raise NotImplementedError(
            "Batch dimensions are not supported for 0d arrays. "
            "A default JAX rewrite should have materialized the implicit arange"
        )

    a_batch_ndim = a.type.ndim - a_core_ndim
    if op.has_p_param:
        [p] = p
        [p_core_ndim] = p_core_ndim
        p_batch_ndim = p.type.ndim - p_core_ndim

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        rng_key, sampling_key = jax.random.split(rng_key, 2)

        if op.has_p_param:
            a, p, core_shape = parameters
        else:
            a, core_shape = parameters
            p = None
        # core_shape arrives as an array; JAX needs a static tuple of ints
        core_shape = tuple(np.asarray(core_shape))

        if batch_ndim == 0:
            sample = jax.random.choice(
                sampling_key, a, shape=core_shape, replace=False, p=p
            )
        else:
            if size is None:
                # Implicit size: broadcast of the parameters' batch shapes
                if p is None:
                    size = a.shape[:a_batch_ndim]
                else:
                    size = jax.numpy.broadcast_shapes(
                        a.shape[:a_batch_ndim],
                        p.shape[:p_batch_ndim],
                    )

            a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:])
            if p is not None:
                p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:])

            # One independent key per batch element
            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))

            # Ravel the batch dimensions because vmap only works along a single axis
            raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
            if p is None:
                raveled_sample = jax.vmap(
                    lambda key, a: jax.random.choice(
                        key, a, shape=core_shape, replace=False, p=None
                    )
                )(batch_sampling_keys, raveled_batch_a)
            else:
                raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:])
                raveled_sample = jax.vmap(
                    lambda key, a, p: jax.random.choice(
                        key, a, shape=core_shape, replace=False, p=p
                    )
                )(batch_sampling_keys, raveled_batch_a, raveled_batch_p)

            # Reshape the batch dimensions back to their original form
            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])

        rng["jax_state"] = rng_key
        return (rng, sample)

    return sample_fn
@jax_sample_fn.register(ptr.PermutationRV)
def jax_sample_fn_permutation(op):
def jax_sample_fn_permutation(op, node):
"""JAX implementation of `PermutationRV`."""
batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(x,) = parameters
sample = jax.random.permutation(sampling_key, x)
if batch_ndim:
# jax.random.permutation has no concept of batch dims
x_core_shape = x.shape[x_batch_ndim:]
if size is None:
size = x.shape[:x_batch_ndim]
else:
x = jax.numpy.broadcast_to(x, size + x_core_shape)
batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
batch_sampling_keys, raveled_batch_x
)
sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
else:
sample = jax.random.permutation(sampling_key, x)
rng["jax_state"] = rng_key
return (rng, sample)
......@@ -321,7 +402,7 @@ def jax_sample_fn_permutation(op):
@jax_sample_fn.register(ptr.BinomialRV)
def jax_sample_fn_binomial(op):
def jax_sample_fn_binomial(op, node):
if not numpyro_available:
raise NotImplementedError(
f"No JAX implementation for the given distribution: {op.name}. "
......@@ -344,7 +425,7 @@ def jax_sample_fn_binomial(op):
@jax_sample_fn.register(ptr.MultinomialRV)
def jax_sample_fn_multinomial(op):
def jax_sample_fn_multinomial(op, node):
if not numpyro_available:
raise NotImplementedError(
f"No JAX implementation for the given distribution: {op.name}. "
......@@ -367,7 +448,7 @@ def jax_sample_fn_multinomial(op):
@jax_sample_fn.register(ptr.VonMisesRV)
def jax_sample_fn_vonmises(op):
def jax_sample_fn_vonmises(op, node):
if not numpyro_available:
raise NotImplementedError(
f"No JAX implementation for the given distribution: {op.name}. "
......
......@@ -210,7 +210,6 @@ def {sized_fn_name}({random_fn_input_names}):
@numba_funcify.register(ptr.BinomialRV)
@numba_funcify.register(ptr.MultinomialRV)
@numba_funcify.register(ptr.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(ptr.ChoiceRV) # the `p` argument is not supported
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
......@@ -367,3 +366,63 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
return (rng, np.random.dirichlet(alphas, size))
return dirichlet_rv
@numba_funcify.register(ptr.ChoiceWithoutReplacement)
def numba_funcify_choice_without_replacement(op, node, **kwargs):
    """Numba implementation of `ChoiceWithoutReplacement`.

    Only the non-batched case is supported, because Numba's
    ``np.random.choice`` does not accept an `a` with ndim > 1.
    """
    batch_ndim = op.batch_ndim(node)
    if batch_ndim:
        # The code isn't too hard to write, but Numba doesn't support a with ndim > 1,
        # and I don't want to change the batched tests for this
        # We'll just raise an error for now
        raise NotImplementedError(
            "ChoiceWithoutReplacement with batch_ndim not supported in Numba backend"
        )

    # Static length of the core_shape vector, needed to build a fixed tuple
    [core_shape_len] = node.inputs[-1].type.shape

    if op.has_p_param:

        @numba_basic.numba_njit
        def choice_without_replacement_rv(rng, size, dtype, a, p, core_shape):
            core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
            samples = np.random.choice(a, size=core_shape, replace=False, p=p)
            return (rng, samples)

    else:

        @numba_basic.numba_njit
        def choice_without_replacement_rv(rng, size, dtype, a, core_shape):
            core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
            samples = np.random.choice(a, size=core_shape, replace=False)
            return (rng, samples)

    return choice_without_replacement_rv
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_permutation(op, node, **kwargs):
    """Numba implementation of `PermutationRV`, with explicit batch-dim loop."""
    # PyTensor uses size=() to represent size=None
    size_is_none = node.inputs[1].type.shape == (0,)
    batch_ndim = op.batch_ndim(node)
    x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]

    @numba_basic.numba_njit
    def permutation_rv(rng, size, dtype, x):
        if batch_ndim:
            x_core_shape = x.shape[x_batch_ndim:]
            if size_is_none:
                # Implicit size: take the batch shape from x itself
                size = x.shape[:batch_ndim]
            else:
                size = numba_ndarray.to_fixed_tuple(size, batch_ndim)
                x = np.broadcast_to(x, size + x_core_shape)

            # np.random.permutation has no batch support: loop over batch indices
            samples = np.empty(size + x_core_shape, dtype=x.dtype)
            for index in np.ndindex(size):
                samples[index] = np.random.permutation(x[index])

        else:
            samples = np.random.permutation(x)

        return (rng, samples)

    return permutation_rv
from collections.abc import Sequence
from copy import copy
from typing import cast
import numpy as np
......@@ -306,6 +307,9 @@ class RandomVariable(Op):
return Apply(self, inputs, outputs)
def batch_ndim(self, node: Apply) -> int:
return cast(int, node.default_output().type.ndim - self.ndim_supp)
def perform(self, node, inputs, outputs):
rng_var_out, smpl_out = outputs
......
......@@ -3,10 +3,18 @@ from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.tensor import abs as abs_t
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
from pytensor.tensor.basic import MakeVector, cast, ones_like, switch, zeros_like
from pytensor.tensor.basic import (
MakeVector,
arange,
cast,
ones_like,
switch,
zeros_like,
)
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
BetaBinomialRV,
ChoiceWithoutReplacement,
GenGammaRV,
GeometricRV,
HalfNormalRV,
......@@ -137,6 +145,32 @@ def beta_binomial_from_beta_binomial(fgraph, node):
return [next_rng, b]
@node_rewriter([ChoiceWithoutReplacement])
def materialize_implicit_arange_choice_without_replacement(fgraph, node):
    """JAX random.choice does not support 0d arrays but when we have batch_ndim we need to vmap through batched `a`.

    This rewrite materializes the implicit `a` as an explicit `arange(a)`.
    """
    op = node.op
    if op.batch_ndim(node) == 0 or op.ndims_params[0] > 0:
        # No need to materialize arange
        return None

    rng, size, dtype, a_scalar_param, *other_params = node.inputs
    if a_scalar_param.type.ndim > 0:
        # Automatic vectorization could have made this parameter batched,
        # there is no nice way to materialize a batched arange
        return None

    a_vector_param = arange(a_scalar_param)

    # Rebuild the Op with the `a` parameter promoted from 0d to 1d
    new_props_dict = op._props_dict().copy()
    new_ndims_params = list(op.ndims_params)
    new_ndims_params[0] += 1
    new_props_dict["ndims_params"] = new_ndims_params
    new_op = type(op)(**new_props_dict)
    return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs
random_vars_opt = SequenceDB()
random_vars_opt.register(
"lognormal_from_normal",
......@@ -178,6 +212,11 @@ random_vars_opt.register(
in2out(beta_binomial_from_beta_binomial),
"jax",
)
random_vars_opt.register(
"materialize_implicit_arange_choice_without_replacement",
in2out(materialize_implicit_arange_choice_without_replacement),
"jax",
)
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
optdb.register(
......
......@@ -13,6 +13,11 @@ from pytensor.tensor.random.basic import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.utils import RandomStream
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
from tests.tensor.random.test_basic import (
batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester,
)
jax = pytest.importorskip("jax")
......@@ -574,7 +579,7 @@ def test_random_dirichlet(parameter, size):
def test_random_choice():
# `replace=True` and `p is None`
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
g = pt.random.choice(np.arange(4), size=10_000, rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
......@@ -644,6 +649,19 @@ def test_random_permutation():
np.testing.assert_allclose(array, permuted)
@pytest.mark.parametrize(
    "batch_dims_tester",
    [
        batched_unweighted_choice_without_replacement_tester,
        batched_weighted_choice_without_replacement_tester,
        batched_permutation_tester,
    ],
)
def test_unnatural_batched_dims(batch_dims_tester):
    """Tests for RVs that don't have natural batch dims in JAX API."""
    batch_dims_tester(mode="JAX")
def test_random_geometric():
rng = shared(np.random.RandomState(123))
p = np.array([0.3, 0.7])
......@@ -792,7 +810,7 @@ def test_random_custom_implementation():
from pytensor.link.jax.dispatch.random import jax_sample_fn
@jax_sample_fn.register(CustomRV)
def jax_sample_fn_custom(op):
def jax_sample_fn_custom(op, node):
def sample_fn(rng, size, dtype, *parameters):
return (rng, 0)
......
import contextlib
from functools import partial
import numpy as np
import pytest
......@@ -17,6 +18,11 @@ from tests.link.numba.test_basic import (
numba_mode,
set_test_value,
)
from tests.tensor.random.test_basic import (
batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester,
)
rng = np.random.default_rng(42849)
......@@ -281,6 +287,88 @@ rng = np.random.default_rng(42849)
pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [4, 3, 2])),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(
ptr.permutation,
[
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)),
],
(),
),
(
partial(ptr.choice, replace=True),
[
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)),
],
pt.as_tensor([2]),
),
(
# p must be passed by kwarg
lambda a, p, size, rng: ptr.choice(
a, p=p, size=size, replace=True, rng=rng
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value(
pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64)
),
],
(),
),
(
partial(ptr.choice, replace=False),
[
set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)),
],
pt.as_tensor([2]),
),
pytest.param(
partial(ptr.choice, replace=False),
[
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)),
],
pt.as_tensor([2]),
marks=pytest.mark.xfail(
raises=ValueError,
reason="Numba random.choice does not support >=1D `a`",
),
),
pytest.param(
# p must be passed by kwarg
lambda a, p, size, rng: ptr.choice(
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.vector(), np.arange(5, dtype=np.float64)),
# Boring p, because the variable is not truly "aligned"
set_test_value(
pt.dvector(),
np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64),
),
],
(),
marks=pytest.mark.xfail(
raises=Exception, # numba.TypeError
reason="Numba random.choice does not support `p` parameter",
),
),
pytest.param(
# p must be passed by kwarg
lambda a, p, size, rng: ptr.choice(
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
# Boring p, because the variable is not truly "aligned"
set_test_value(
pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64)
),
],
(),
marks=pytest.mark.xfail(
raises=ValueError,
reason="Numba random.choice does not support >=1D `a`",
),
),
],
ids=str,
)
......@@ -595,3 +683,22 @@ def test_random_Generator():
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize(
    "batch_dims_tester",
    [
        # The choice testers are expected to fail: the Numba backend raises
        # NotImplementedError for ChoiceWithoutReplacement with batch dims
        pytest.param(
            batched_unweighted_choice_without_replacement_tester,
            marks=pytest.mark.xfail(raises=NotImplementedError),
        ),
        pytest.param(
            batched_weighted_choice_without_replacement_tester,
            marks=pytest.mark.xfail(raises=NotImplementedError),
        ),
        batched_permutation_tester,
    ],
)
def test_unnatural_batched_dims(batch_dims_tester):
    """Tests for RVs that don't have natural batch dims in Numba API."""
    batch_dims_tester(mode="NUMBA", rng_ctor=np.random.RandomState)
......@@ -18,6 +18,8 @@ from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import ones, stack
from pytensor.tensor.random.basic import (
ChoiceWithoutReplacement,
PermutationRV,
_gamma,
bernoulli,
beta,
......@@ -1394,9 +1396,6 @@ def test_integers_samples():
def test_choice_samples():
with pytest.raises(NotImplementedError):
choice._supp_shape_from_params(np.asarray(5))
compare_sample_values(choice, np.asarray(5))
compare_sample_values(choice, np.asarray([5]))
compare_sample_values(choice, np.array([1.0, 5.0], dtype=config.floatX))
......@@ -1423,19 +1422,6 @@ def test_choice_samples():
compare_sample_values(choice, pt.as_tensor_variable([1, 2, 3]), 2, replace=True)
def test_choice_infer_shape():
node = choice([0, 1]).owner
res = node.op._infer_shape((), node.inputs[3:], None)
assert tuple(res.eval()) == ()
node = choice([0, 1]).owner
# The param_shape of a NoneConst is None, during shape_inference
res = node.op._infer_shape(
(), node.inputs[3:], (node.inputs[3].shape, None, node.inputs[5].shape)
)
assert tuple(res.eval()) == ()
def test_permutation_samples():
compare_sample_values(
permutation,
......@@ -1455,11 +1441,195 @@ def test_permutation_shape():
assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5)
def batched_unweighted_choice_without_replacement_tester(
    mode="FAST_RUN", rng_ctor=np.random.default_rng
):
    """Test unweighted choice without replacement with batched ndims.

    This has no corresponding in numpy, but is supported for consistency within the
    RandomVariable API.

    It can be triggered by manually building the Op or during automatic vectorization.
    """
    rng = shared(rng_ctor())

    # Batched a, implicit size
    a_core_ndim = 2
    core_shape_len = 1
    rv_op = ChoiceWithoutReplacement(
        ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
        ndims_params=[a_core_ndim, core_shape_len],
        dtype="int64",
    )

    a = np.arange(3 * 5 * 2).reshape((3, 5, 2))
    core_shape = (4,)
    rv = rv_op(a, core_shape, rng=rng)
    assert rv.type.shape == (3, 4, 2)

    draws = rv.eval(mode=mode)
    for i in range(3):
        draw = draws[i]
        # 4 rows of 2 entries each, all distinct (no replacement)
        assert np.unique(draw).size == 8
        # Each batch draws only from its own slice of `a`
        assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))

    # Explicit size broadcasts beyond a
    a_core_ndim = 2
    core_shape_len = 2
    rv_op = ChoiceWithoutReplacement(
        ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
        ndims_params=[a_core_ndim, len(core_shape)],
        dtype="int64",
    )

    core_shape = (4, 1)
    rv = rv_op(a, core_shape, size=(2, 3), rng=rng)
    assert rv.type.shape == (2, 3, 4, 1, 2)

    draws = rv.eval(mode=mode)
    for j in range(2):
        for i in range(3):
            draw = draws[j, i]
            assert np.unique(draw).size == 8
            assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
def batched_weighted_choice_without_replacement_tester(
    mode="FAST_RUN", rng_ctor=np.random.default_rng
):
    """Test weighted choice without replacement with batched ndims.

    This has no corresponding in numpy, but is supported for consistency within the
    RandomVariable API.

    It can be triggered by manually building the Op or during automatic vectorization.
    """
    rng = shared(rng_ctor())

    # 3 ndims params indicates p is passed
    a_core_ndim = 2
    core_shape_len = 1
    rv_op = ChoiceWithoutReplacement(
        ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
        ndims_params=[a_core_ndim, 1, 1],
        dtype="int64",
    )

    # Batched a, implicit size
    a = np.arange(4 * 5 * 2).reshape((4, 5, 2))
    p = np.array([0.0, 0.25, 0.25, 0.25, 0.25])
    core_shape = (3,)
    rv = rv_op(a, p, core_shape, rng=rng)
    assert rv.type.shape == (4, 3, 2)

    draws = rv.eval(mode=mode)
    for i in range(4):
        draw = draws[i].ravel()
        assert np.unique(draw).size == 6
        # The first two entries after each step of 10 have zero probability
        assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10))

    # p and a are batched
    # Test implicit arange
    a_core_ndim = 0
    core_shape_len = 2
    rv_op = ChoiceWithoutReplacement(
        ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
        ndims_params=[a_core_ndim, 1, 1],
        dtype="int64",
    )
    a = 6
    p = np.array(
        [
            # Only even numbers allowed
            [1 / 3, 0.0, 1 / 3, 0.0, 1 / 3, 0.0],
            # Only odd numbers allowed
            [0.0, 1 / 3, 0.0, 1 / 3, 0.0, 1 / 3],
        ]
    )
    core_shape = (3, 1)
    rv = rv_op(a, p, core_shape, rng=rng)
    assert rv.type.shape == (2, 3, 1)

    draws = rv.eval(mode=mode)
    for i in range(2):
        draw = np.asarray(draws[i].ravel())
        assert set(draw) == set(range(i, 6, 2))

    # Size broadcasts beyond a
    a_core_ndim = 2
    core_shape_len = 1
    rv_op = ChoiceWithoutReplacement(
        ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
        ndims_params=[a_core_ndim, 1, 1],
        dtype="int64",
    )
    a = np.arange(4 * 5 * 2).reshape((4, 5, 2))
    p = np.array([0.0, 0.25, 0.25, 0.25, 0.25])
    core_shape = (3,)
    rv = rv_op(a, p, core_shape, size=(5, 1, 4))
    assert rv.type.shape == (5, 1, 4, 3, 2)

    draws = rv.eval(mode=mode)
    for j in range(5):
        for i in range(4):
            draw = draws[j, 0, i].ravel()
            assert np.unique(draw).size == 6
            # The first two entries after each step of 10 have zero probability
            assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10))
def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng):
    """Test permutation with batched ndims.

    This has no corresponding in numpy, but is supported for consistency within the
    RandomVariable API.

    It can be triggered by manually building the Op or during automatic vectorization.
    """
    rng = shared(rng_ctor())

    rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64")
    x = np.arange(5 * 3 * 2).reshape((5, 3, 2))

    # Batched x and implicit size
    rv = rv_op(x, rng=rng)
    assert rv.type.shape == (5, 3, 2)
    draws = rv.eval(mode=mode)
    for i in range(5):
        # Each batch is a permutation of its own slice of x
        assert set(np.asarray(draws[i].ravel())) == set(range(i * 6, (i + 1) * 6))

    # Size broadcasts beyond x
    rv = rv_op(x, size=(4, 5), rng=rng)
    assert rv.type.shape == (4, 5, 3, 2)
    draws = rv.eval(mode=mode)
    for j in range(4):
        for i in range(5):
            assert set(np.asarray(draws[j, i].ravel())) == set(
                range(i * 6, (i + 1) * 6)
            )
@pytest.mark.parametrize(
    "batch_dims_tester",
    [
        batched_unweighted_choice_without_replacement_tester,
        batched_weighted_choice_without_replacement_tester,
        batched_permutation_tester,
    ],
)
def test_unnatural_batched_dims(batch_dims_tester):
    """Tests for RVs that don't have natural batch dims in Numpy API."""
    batch_dims_tester()
@config.change_flags(compute_test_value="off")
def test_pickle():
# This is an interesting `Op` case, because it has `None` types and a
# conditional dtype
sample_a = choice(5, size=(2, 3))
sample_a = choice(5, replace=False, size=(2, 3))
a_pkl = pickle.dumps(sample_a)
a_unpkl = pickle.loads(a_pkl)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论