Commit 383d4efe authored by Rémi Louf, committed by Thomas Wiecki

Add tests for JAX `RandomVariable` implementations

Parent d647578c
...@@ -207,3 +207,8 @@ def test_jax_checkandraise(): ...@@ -207,3 +207,8 @@ def test_jax_checkandraise():
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
function((p,), res, mode=jax_mode) function((p,), res, mode=jax_mode)
def set_test_value(x, v):
    """Attach *v* as the test value on *x*'s ``tag`` and return *x*.

    Convenience helper so a variable can be created and given a test
    value in a single expression inside parametrize tables.
    """
    setattr(x.tag, "test_value", v)
    return x
import re

import numpy as np
import pytest
import scipy.stats as stats
from packaging.version import parse as version_parse

import pytensor
import pytensor.tensor as at
import pytensor.tensor.random as aer
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.basic import RandomVariable
from pytensor.tensor.random.utils import RandomStream

from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value

jax = pytest.importorskip("jax")
def test_random_RandomStream():
    """Two successive calls of a compiled graph using `RandomStream` should
    return different values.

    """
    srng = RandomStream(seed=123)
    out = srng.normal() - srng.normal()

    # Compiling to JAX drops the RandomType shared-variable updates; PyTensor
    # reports this with a warning, which we require here.
    with pytest.warns(
        UserWarning,
        match=r"The RandomType SharedVariables \[.+\] will not be used",
    ):
        fn = function([], out, mode=jax_mode)

    jax_res_1 = fn()
    jax_res_2 = fn()

    # Two successive calls must produce different draws.
    assert not np.array_equal(jax_res_1, jax_res_2)
@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng))
def test_random_updates(rng_ctor):
    """Compiling an RNG `updates` entry to JAX must warn, and must leave the
    original shared RNG object's state untouched.

    Parameterized over both the legacy `RandomState` and the newer
    `Generator` constructors.
    """
    original_value = rng_ctor(seed=98)
    rng = shared(original_value, name="original_rng", borrow=False)
    # A `RandomVariable` node has two outputs: the next RNG state and the draw.
    next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs
    # JAX mode cannot honor RandomType shared-variable updates; PyTensor warns
    # with the shared variable's name, and we assert that exact message.
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "The RandomType SharedVariables [original_rng] will not be used"
        ),
    ):
        f = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)
    # Two successive calls still return different draws.
    assert f() != f()
    # Check that original rng variable content was not overwritten when calling jax_typify
    assert all(
        a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
        for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__())
    )
@pytest.mark.parametrize(
    "rv_op, dist_params, base_size, cdf_name, params_conv",
    [
        (
            aer.beta,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "beta",
            lambda *args: args,
        ),
        (
            aer.cauchy,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "cauchy",
            lambda *args: args,
        ),
        (
            aer.exponential,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
            ],
            (2,),
            "expon",
            # `scipy.stats.expon` takes (loc, scale).
            lambda *args: (0, args[0]),
        ),
        (
            aer.gamma,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "gamma",
            # `scipy.stats.gamma` takes (a, loc, scale).
            lambda a, b: (a, 0.0, b),
        ),
        (
            aer.laplace,
            [
                set_test_value(at.dvector(), np.array([1.0, 2.0], dtype=np.float64)),
                set_test_value(at.dscalar(), np.array(1.0, dtype=np.float64)),
            ],
            (2,),
            "laplace",
            lambda *args: args,
        ),
        (
            aer.logistic,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "logistic",
            lambda *args: args,
        ),
        (
            aer.normal,
            [
                set_test_value(
                    at.lvector(),
                    np.array([1, 2], dtype=np.int64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1.0, dtype=np.float64),
                ),
            ],
            (2,),
            "norm",
            lambda *args: args,
        ),
        (
            aer.pareto,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                )
            ],
            (2,),
            "pareto",
            lambda *args: args,
        ),
        (
            aer.poisson,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1000.0, 2000.0], dtype=np.float64),
                ),
            ],
            (2,),
            "poisson",
            lambda *args: args,
        ),
        (
            aer.randint,
            [
                set_test_value(
                    at.lscalar(),
                    np.array(0, dtype=np.int64),
                ),
                set_test_value(  # high-value necessary since test on cdf
                    at.lscalar(),
                    np.array(1000, dtype=np.int64),
                ),
            ],
            (),
            "randint",
            lambda *args: args,
        ),
        (
            aer.uniform,
            [
                set_test_value(
                    at.dvector(),
                    np.array([1.0, 2.0], dtype=np.float64),
                ),
                set_test_value(
                    at.dscalar(),
                    np.array(1000.0, dtype=np.float64),
                ),
            ],
            (2,),
            "uniform",
            lambda *args: args,
        ),
    ],
)
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
    """The JAX samplers are not one-to-one with NumPy samplers so we
    need to use a statistical test to make sure that the transpilation
    is correct.

    Parameters
    ----------
    rv_op
        The transpiled `RandomVariable` `Op`.
    dist_params
        The parameters passed to the op.
    base_size
        Batch shape of the distribution parameters, appended to the
        10_000-draw sample axis.
    cdf_name
        Name of the matching `scipy.stats` distribution.
    params_conv
        Converts the op's parameters to the `scipy.stats` parameterization.

    """
    rng = shared(np.random.RandomState(29402))
    g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
    g_fn = function(dist_params, g, mode=jax_mode)
    # Call with the test values of the explicit (non-shared, non-constant)
    # graph inputs.
    samples = g_fn(
        *[
            i.tag.test_value
            for i in g_fn.maker.fgraph.inputs
            if not isinstance(i, (SharedVariable, Constant))
        ]
    )

    bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_params])

    # Run one Cramér–von Mises goodness-of-fit test per batch entry.
    for idx in np.ndindex(*base_size):
        cdf_params = params_conv(*tuple(arg[idx] for arg in bcast_dist_args))
        test_res = stats.cramervonmises(
            samples[(Ellipsis,) + idx], cdf_name, args=cdf_params
        )
        assert test_res.pvalue > 0.1
@pytest.mark.parametrize("size", [(), (4,)])
def test_random_bernoulli(size):
    """Bernoulli(0.5) draws compiled through JAX should average to ~0.5."""
    random_state = shared(np.random.RandomState(123))
    graph = at.random.bernoulli(0.5, size=(1000,) + size, rng=random_state)
    sample_fn = function([], graph, mode=jax_mode)
    draws = sample_fn()
    # Very loose tolerance (rtol=1): only checks the mean is in the right range.
    np.testing.assert_allclose(draws.mean(axis=0), 0.5, 1)
def test_random_mvnormal():
    """Multivariate-normal draws compiled through JAX should center on the mean."""
    random_state = shared(np.random.RandomState(123))
    mean_vec = np.ones(4)
    cov_mat = np.eye(4)
    graph = at.random.multivariate_normal(
        mean_vec, cov_mat, size=(10000,), rng=random_state
    )
    sample_fn = function([], graph, mode=jax_mode)
    draws = sample_fn()
    np.testing.assert_allclose(draws.mean(axis=0), mean_vec, atol=0.1)
@pytest.mark.parametrize(
    "parameter, size",
    [
        (np.ones(4), ()),
        (np.ones(4), (2, 4)),
    ],
)
def test_random_dirichlet(parameter, size):
    """Dirichlet draws compiled through JAX should have plausible means."""
    random_state = shared(np.random.RandomState(123))
    graph = at.random.dirichlet(parameter, size=(1000,) + size, rng=random_state)
    sample_fn = function([], graph, mode=jax_mode)
    draws = sample_fn()
    # Loose check (rtol=1) that the per-component means are in range.
    np.testing.assert_allclose(draws.mean(axis=0), 0.5, 1)
def test_random_choice():
    """Check the JAX translation of `choice`: equal frequencies, uniqueness
    with `replace=False`, and explicit probability weights."""
    # Elements are picked at equal frequency
    n_draws = 10000
    random_state = shared(np.random.RandomState(123))
    graph = at.random.choice(np.arange(4), size=n_draws, rng=random_state)
    sample_fn = function([], graph, mode=jax_mode)
    draws = sample_fn()
    np.testing.assert_allclose(np.sum(draws == 3) / n_draws, 0.25, 2)

    # `replace=False` produces unique results
    random_state = shared(np.random.RandomState(123))
    graph = at.random.choice(
        np.arange(100), replace=False, size=99, rng=random_state
    )
    sample_fn = function([], graph, mode=jax_mode)
    draws = sample_fn()
    assert len(np.unique(draws)) == 99

    # We can pass an array with probabilities
    random_state = shared(np.random.RandomState(123))
    graph = at.random.choice(
        np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=random_state
    )
    sample_fn = function([], graph, mode=jax_mode)
    draws = sample_fn()
    # All probability mass is on element 0, so every draw must be 0.
    np.testing.assert_allclose(draws, np.zeros(10))
def test_random_categorical():
    """Categorical over 4 equiprobable classes: mean draw is (0+1+2+3)/4."""
    random_state = shared(np.random.RandomState(123))
    graph = at.random.categorical(
        0.25 * np.ones(4), size=(10000, 4), rng=random_state
    )
    sample_fn = function([], graph, mode=jax_mode)
    draws = sample_fn()
    np.testing.assert_allclose(draws.mean(axis=0), 6 / 4, 1)
def test_random_permutation():
    """A JAX-compiled permutation of 0..3 should differ from the input order."""
    values = np.arange(4)
    random_state = shared(np.random.RandomState(123))
    graph = at.random.permutation(values, rng=random_state)
    permute_fn = function([], graph, mode=jax_mode)
    shuffled = permute_fn()
    # With this seed the result is not the identity permutation, so the
    # element-wise comparison is expected to fail.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(values, shuffled)
def test_random_unimplemented(): def test_random_unimplemented():
"""Compiling a graph with a non-supported `RandomVariable` should
raise an error.
"""
class NonExistentRV(RandomVariable): class NonExistentRV(RandomVariable):
name = "non-existent" name = "non-existent"
ndim_supp = 0 ndim_supp = 0
...@@ -78,38 +362,58 @@ def test_random_unimplemented(): ...@@ -78,38 +362,58 @@ def test_random_unimplemented():
compare_jax_and_py(fgraph, []) compare_jax_and_py(fgraph, [])
def test_random_custom_implementation():
    """We can register a JAX implementation for user-defined `RandomVariable`s"""

    class CustomRV(RandomVariable):
        name = "non-existent"
        ndim_supp = 0
        ndims_params = []
        dtype = "floatX"

        def __call__(self, size=None, **kwargs):
            return super().__call__(size=size, **kwargs)

        def rng_fn(cls, rng, size):
            # Deterministic reference value used by the Python-mode side of
            # `compare_jax_and_py`.
            return 0

    from pytensor.link.jax.dispatch.random import jax_sample_fn

    @jax_sample_fn.register(CustomRV)
    def jax_sample_fn_custom(op):
        def sample_fn(rng, size, dtype, *parameters):
            # Return the (unchanged) rng state and the same constant sample.
            return (rng, 0)

        return sample_fn

    nonexistentrv = CustomRV()
    rng = shared(np.random.RandomState(123))
    out = nonexistentrv(rng=rng)
    fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
    compare_jax_and_py(fgraph, [])
def test_random_concrete_shape():
    """JAX should compile when a `RandomVariable` is passed a concrete shape.

    There are three quantities that JAX considers as concrete:
    1. Constants known at compile time;
    2. The shape of an array.
    3. `static_argnums` parameters
    This test makes sure that graphs with `RandomVariable`s compile when the
    `size` parameter satisfies either of these criteria.

    """
    rng = shared(np.random.RandomState(123))
    x_at = at.dmatrix()
    # `size=(3,)` is a compile-time constant and `size=x_at.shape` is an
    # array shape — both are concrete to JAX, so compilation must succeed.
    f = at.random.normal(0, 1, size=(3,), rng=rng)
    g = at.random.normal(f, 1, size=x_at.shape, rng=rng)
    g_fn = function([x_at], g, mode=jax_mode)
    _ = g_fn(np.ones((2, 3)))

    # NOTE(review): the comment below states the intended future behavior
    # (forwarding `size_at` via `static_argnums`), but a symbolic scalar
    # `size` currently raises `NotImplementedError` — hence the raises guard.
    # This should compile, and `size_at` be passed to the list of `static_argnums`.
    with pytest.raises(NotImplementedError):
        size_at = at.scalar()
        g = at.random.normal(f, 1, size=size_at, rng=rng)
        g_fn = function([size_at], g, mode=jax_mode)
        _ = g_fn(10)
Markdown format supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment