提交 383d4efe authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Add tests for JAX `RandomVariable` implementations

上级 d647578c
......@@ -207,3 +207,8 @@ def test_jax_checkandraise():
with pytest.warns(UserWarning):
function((p,), res, mode=jax_mode)
def set_test_value(x, v):
    """Attach *v* as the test value on *x*'s tag and return *x* for chaining."""
    setattr(x.tag, "test_value", v)
    return x
......@@ -2,61 +2,345 @@ import re
import numpy as np
import pytest
from packaging.version import parse as version_parse
import scipy.stats as stats
import pytensor
import pytensor.tensor as at
import pytensor.tensor.random as aer
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.basic import RandomVariable
from pytensor.tensor.random.utils import RandomStream
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
jax = pytest.importorskip("jax")
@pytest.mark.xfail(
    version_parse(jax.__version__) >= version_parse("0.2.26"),
    reason="JAX samplers require concrete/static shape values?",
)
def test_random_RandomStream():
    """A compiled `RandomStream` graph must yield fresh draws on every call."""
    srng = RandomStream(seed=123)
    diff_of_normals = srng.normal() - srng.normal()

    # Compiling RandomType shared variables for JAX emits a warning.
    with pytest.warns(
        UserWarning,
        match=r"The RandomType SharedVariables \[.+\] will not be used",
    ):
        compiled = function([], diff_of_normals, mode=jax_mode)

    first_draw = compiled()
    second_draw = compiled()
    assert not np.array_equal(first_draw, second_draw)
@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng))
def test_random_updates(rng_ctor):
    """An RNG update must advance draws without clobbering the shared RNG's
    original content when it is converted for JAX."""
    seed_rng = rng_ctor(seed=98)
    rng = shared(seed_rng, name="original_rng", borrow=False)
    next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs

    with pytest.warns(
        UserWarning,
        match=re.escape(
            "The RandomType SharedVariables [original_rng] will not be used"
        ),
    ):
        step = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)

    assert step() != step()

    # Check that original rng variable content was not overwritten when calling jax_typify
    def _state_equal(a, b):
        return np.array_equal(a, b) if isinstance(a, np.ndarray) else a == b

    assert all(
        _state_equal(a, b)
        for a, b in zip(rng.get_value().__getstate__(), seed_rng.__getstate__())
    )
@pytest.mark.parametrize(
    "rv_op, dist_params, base_size, cdf_name, params_conv",
    [
        (
            aer.beta,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "beta",
            lambda *args: args,
        ),
        (
            aer.cauchy,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "cauchy",
            lambda *args: args,
        ),
        (
            aer.exponential,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
            ],
            (2,),
            "expon",
            lambda *args: (0, args[0]),
        ),
        (
            aer.gamma,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "gamma",
            lambda a, b: (a, 0.0, b),
        ),
        (
            aer.laplace,
            [
                set_test_value(at.dvector(), np.array([1.0, 2.0], dtype=np.float64)),
                set_test_value(at.dscalar(), np.array(1.0, dtype=np.float64)),
            ],
            (2,),
            "laplace",
            lambda *args: args,
        ),
        (
            aer.logistic,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "logistic",
            lambda *args: args,
        ),
        (
            aer.normal,
            [
                set_test_value(
                    at.lvector(),
                    np.array([1, 2], dtype=np.int64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "norm",
            lambda *args: args,
        ),
        (
            aer.pareto,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                )
            ],
            (2,),
            "pareto",
            lambda *args: args,
        ),
        (
            aer.poisson,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1000.0, 2000.0], dtype=np.float64),
                ),
            ],
            (2,),
            "poisson",
            lambda *args: args,
        ),
        (
            aer.randint,
            [
                set_test_value(
                    at.lscalar(),
                    np.array(0, dtype=np.int64),
                ),
                set_test_value(  # high-value necessary since test on cdf
                    at.lscalar(),
                    np.array(1000, dtype=np.int64),
                ),
            ],
            (),
            "randint",
            lambda *args: args,
        ),
        (
            aer.uniform,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1000.0, dtype=np.float64),
                ),
            ],
            (2,),
            "uniform",
            lambda *args: args,
        ),
    ],
)
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
    """The JAX samplers are not one-to-one with NumPy samplers so we
    need to use a statistical test to make sure that the transpilation
    is correct.

    Parameters
    ----------
    rv_op
        The transpiled `RandomVariable` `Op`.
    dist_params
        The parameters passed to the op.
    base_size
        The shape of the distribution's parameter dimensions (the draw
        shape minus the leading sample dimension).
    cdf_name
        Name of the `scipy.stats` distribution used in the
        Cramér–von Mises goodness-of-fit test.
    params_conv
        Callable mapping the op's parameter values to the argument tuple
        expected by the corresponding `scipy.stats` CDF.

    """
    rng = shared(np.random.RandomState(29402))
    g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
    g_fn = function(dist_params, g, mode=jax_mode)
    # Feed the symbolic parameters their stored test values; shared/constant
    # inputs are supplied by the graph itself.
    samples = g_fn(
        *[
            i.tag.test_value
            for i in g_fn.maker.fgraph.inputs
            if not isinstance(i, (SharedVariable, Constant))
        ]
    )

    bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_params])

    # Run a goodness-of-fit test per parameter combination.
    for idx in np.ndindex(*base_size):
        cdf_params = params_conv(*tuple(arg[idx] for arg in bcast_dist_args))
        test_res = stats.cramervonmises(
            samples[(Ellipsis,) + idx], cdf_name, args=cdf_params
        )
        assert test_res.pvalue > 0.1
@pytest.mark.parametrize("size", [(), (4,)])
def test_random_bernoulli(size):
    """Bernoulli(0.5) draws compiled via JAX should average to roughly 0.5."""
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng_state)
    sample = function([], draws, mode=jax_mode)()
    np.testing.assert_allclose(sample.mean(axis=0), 0.5, 1)
def test_random_mvnormal():
    """Multivariate-normal sample means should approximate the mean vector."""
    rng_state = shared(np.random.RandomState(123))

    loc = np.ones(4)
    scale = np.eye(4)
    draws = at.random.multivariate_normal(loc, scale, size=(10000,), rng=rng_state)
    sample = function([], draws, mode=jax_mode)()
    np.testing.assert_allclose(sample.mean(axis=0), loc, atol=0.1)
@pytest.mark.parametrize(
    "parameter, size",
    [
        (np.ones(4), ()),
        (np.ones(4), (2, 4)),
    ],
)
def test_random_dirichlet(parameter, size):
    """Dirichlet sampling compiles under JAX and yields plausible means."""
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng_state)
    sample = function([], draws, mode=jax_mode)()
    np.testing.assert_allclose(sample.mean(axis=0), 0.5, 1)
def test_random_choice():
    """Exercise `at.random.choice` under JAX: uniform frequencies, draws
    without replacement, and explicit probability weights."""
    total = 10000

    # Each of the four values should be drawn at roughly equal frequency.
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.choice(np.arange(4), size=total, rng=rng_state)
    sample = function([], draws, mode=jax_mode)()
    np.testing.assert_allclose(np.sum(sample == 3) / total, 0.25, 2)

    # Without replacement, all 99 picks out of 100 values are distinct.
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.choice(np.arange(100), replace=False, size=99, rng=rng_state)
    sample = function([], draws, mode=jax_mode)()
    assert len(np.unique(sample)) == 99

    # A degenerate probability vector forces every draw to index 0.
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.choice(
        np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng_state
    )
    sample = function([], draws, mode=jax_mode)()
    np.testing.assert_allclose(sample, np.zeros(10))
def test_random_categorical():
    """Uniform categorical draws over {0,1,2,3} should average near 1.5."""
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng_state)
    sample = function([], draws, mode=jax_mode)()
    np.testing.assert_allclose(sample.mean(axis=0), 6 / 4, 1)
def test_random_permutation():
    """A seeded JAX permutation of 0..3 should differ from identity order."""
    values = np.arange(4)
    rng_state = shared(np.random.RandomState(123))
    shuffle = function([], at.random.permutation(values, rng=rng_state), mode=jax_mode)
    permuted = shuffle()
    # With this seed the shuffled order is expected to change, so the
    # element-wise comparison must fail.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(values, permuted)
def test_random_unimplemented():
"""Compiling a graph with a non-supported `RandomVariable` should
raise an error.
"""
class NonExistentRV(RandomVariable):
name = "non-existent"
ndim_supp = 0
......@@ -78,38 +362,58 @@ def test_random_unimplemented():
compare_jax_and_py(fgraph, [])
def test_random_custom_implementation():
    """We can register a JAX implementation for user-defined `RandomVariable`s"""

    class CustomRV(RandomVariable):
        # Minimal scalar RV with no distribution parameters.
        name = "non-existent"
        ndim_supp = 0
        ndims_params = []
        dtype = "floatX"

        def __call__(self, size=None, **kwargs):
            return super().__call__(size=size, **kwargs)

        def rng_fn(cls, rng, size):
            return 0

    from pytensor.link.jax.dispatch.random import jax_sample_fn

    @jax_sample_fn.register(CustomRV)
    def jax_sample_fn_custom(op):
        # JAX sampler implementations return the (updated) RNG along with
        # the drawn value.
        def sample_fn(rng, size, dtype, *parameters):
            return (rng, 0)

        return sample_fn

    nonexistentrv = CustomRV()
    rng = shared(np.random.RandomState(123))
    out = nonexistentrv(rng=rng)
    fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
    compare_jax_and_py(fgraph, [])
def test_random_concrete_shape():
    """JAX should compile when a `RandomVariable` is passed a concrete shape.

    There are three quantities that JAX considers as concrete:
    1. Constants known at compile time;
    2. The shape of an array.
    3. `static_argnums` parameters

    This test makes sure that graphs with `RandomVariable`s compile when the
    `size` parameter satisfies either of these criteria.
    """
    rng_state = shared(np.random.RandomState(123))

    # A constant tuple and another array's shape are both concrete sizes.
    data = at.dmatrix()
    base_draw = at.random.normal(0, 1, size=(3,), rng=rng_state)
    dependent_draw = at.random.normal(base_draw, 1, size=data.shape, rng=rng_state)
    compiled = function([data], dependent_draw, mode=jax_mode)
    _ = compiled(np.ones((2, 3)))

    # A symbolic scalar `size` is not concrete, so compilation currently
    # raises; `static_argnums` support would be needed to lift this.
    with pytest.raises(NotImplementedError):
        symbolic_size = at.scalar()
        dependent_draw = at.random.normal(base_draw, 1, size=symbolic_size, rng=rng_state)
        compiled = function([symbolic_size], dependent_draw, mode=jax_mode)
        _ = compiled(10)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论