提交 53b00ea6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Intercept UserWarning on JAX random function tests

上级 93bfa1bd
import re
import numpy as np import numpy as np
import pytest import pytest
import scipy.stats as stats import scipy.stats as stats
...@@ -22,6 +20,13 @@ jax = pytest.importorskip("jax") ...@@ -22,6 +20,13 @@ jax = pytest.importorskip("jax")
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402 from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
def random_function(*args, **kwargs):
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
return function(*args, **kwargs)
def test_random_RandomStream(): def test_random_RandomStream():
"""Two successive calls of a compiled graph using `RandomStream` should """Two successive calls of a compiled graph using `RandomStream` should
return different values. return different values.
...@@ -30,11 +35,7 @@ def test_random_RandomStream(): ...@@ -30,11 +35,7 @@ def test_random_RandomStream():
srng = RandomStream(seed=123) srng = RandomStream(seed=123)
out = srng.normal() - srng.normal() out = srng.normal() - srng.normal()
with pytest.warns( fn = random_function([], out, mode=jax_mode)
UserWarning,
match=r"The RandomType SharedVariables \[.+\] will not be used",
):
fn = function([], out, mode=jax_mode)
jax_res_1 = fn() jax_res_1 = fn()
jax_res_2 = fn() jax_res_2 = fn()
...@@ -47,13 +48,7 @@ def test_random_updates(rng_ctor): ...@@ -47,13 +48,7 @@ def test_random_updates(rng_ctor):
rng = shared(original_value, name="original_rng", borrow=False) rng = shared(original_value, name="original_rng", borrow=False)
next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs
with pytest.warns( f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
UserWarning,
match=re.escape(
"The RandomType SharedVariables [original_rng] will not be used"
),
):
f = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)
assert f() != f() assert f() != f()
# Check that original rng variable content was not overwritten when calling jax_typify # Check that original rng variable content was not overwritten when calling jax_typify
...@@ -83,17 +78,14 @@ def test_random_updates_input_storage_order(): ...@@ -83,17 +78,14 @@ def test_random_updates_input_storage_order():
# This function replaces inp by input_shared in the update expression # This function replaces inp by input_shared in the update expression
# This is what caused the RNG to appear later than inp_shared in the input_storage # This is what caused the RNG to appear later than inp_shared in the input_storage
with pytest.warns(
UserWarning, fn = random_function(
match=r"The RandomType SharedVariables \[.+\] will not be used", inputs=[],
): outputs=[],
fn = pytensor.function( updates={inp_shared: inp_update},
inputs=[], givens={inp: inp_shared},
outputs=[], mode="JAX",
updates={inp_shared: inp_update}, )
givens={inp: inp_shared},
mode="JAX",
)
fn() fn()
np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3) np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
fn() fn()
...@@ -457,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -457,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
else: else:
rng = shared(np.random.RandomState(29402)) rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng) g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
g_fn = function(dist_params, g, mode=jax_mode) g_fn = random_function(dist_params, g, mode=jax_mode)
samples = g_fn( samples = g_fn(
*[ *[
i.tag.test_value i.tag.test_value
...@@ -481,7 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -481,7 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
def test_random_bernoulli(size): def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng) g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
...@@ -492,7 +484,7 @@ def test_random_mvnormal(): ...@@ -492,7 +484,7 @@ def test_random_mvnormal():
mu = np.ones(4) mu = np.ones(4)
cov = np.eye(4) cov = np.eye(4)
g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng) g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1) np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
...@@ -507,7 +499,7 @@ def test_random_mvnormal(): ...@@ -507,7 +499,7 @@ def test_random_mvnormal():
def test_random_dirichlet(parameter, size): def test_random_dirichlet(parameter, size):
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng) g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
...@@ -517,21 +509,21 @@ def test_random_choice(): ...@@ -517,21 +509,21 @@ def test_random_choice():
num_samples = 10000 num_samples = 10000
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = at.random.choice(np.arange(4), size=num_samples, rng=rng) g = at.random.choice(np.arange(4), size=num_samples, rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2) np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
# `replace=False` produces unique results # `replace=False` produces unique results
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng) g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
assert len(np.unique(samples)) == 99 assert len(np.unique(samples)) == 99
# We can pass an array with probabilities # We can pass an array with probabilities
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng) g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples, np.zeros(10)) np.testing.assert_allclose(samples, np.zeros(10))
...@@ -539,7 +531,7 @@ def test_random_choice(): ...@@ -539,7 +531,7 @@ def test_random_choice():
def test_random_categorical(): def test_random_categorical():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng) g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1) np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
...@@ -548,7 +540,7 @@ def test_random_permutation(): ...@@ -548,7 +540,7 @@ def test_random_permutation():
array = np.arange(4) array = np.arange(4)
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = at.random.permutation(array, rng=rng) g = at.random.permutation(array, rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
permuted = g_fn() permuted = g_fn()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
np.testing.assert_allclose(array, permuted) np.testing.assert_allclose(array, permuted)
...@@ -558,7 +550,7 @@ def test_random_geometric(): ...@@ -558,7 +550,7 @@ def test_random_geometric():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = at.random.geometric(p, size=(10_000, 2), rng=rng) g = at.random.geometric(p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1) np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
...@@ -569,7 +561,7 @@ def test_negative_binomial(): ...@@ -569,7 +561,7 @@ def test_negative_binomial():
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng) g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -583,7 +575,7 @@ def test_binomial(): ...@@ -583,7 +575,7 @@ def test_binomial():
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = at.random.binomial(n, p, size=(10_000, 2), rng=rng) g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1) np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
...@@ -598,7 +590,7 @@ def test_beta_binomial(): ...@@ -598,7 +590,7 @@ def test_beta_binomial():
a = np.array([1.5, 13]) a = np.array([1.5, 13])
b = np.array([0.5, 9]) b = np.array([0.5, 9])
g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng) g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -616,7 +608,7 @@ def test_multinomial(): ...@@ -616,7 +608,7 @@ def test_multinomial():
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng) g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -632,7 +624,7 @@ def test_vonmises_mu_outside_circle(): ...@@ -632,7 +624,7 @@ def test_vonmises_mu_outside_circle():
mu = np.array([-30, 40]) mu = np.array([-30, 40])
kappa = np.array([100, 10]) kappa = np.array([100, 10])
g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng) g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode) g_fn = random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose( np.testing.assert_allclose(
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1 samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
...@@ -678,7 +670,10 @@ def test_random_unimplemented(): ...@@ -678,7 +670,10 @@ def test_random_unimplemented():
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
compare_jax_and_py(fgraph, []) with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
compare_jax_and_py(fgraph, [])
def test_random_custom_implementation(): def test_random_custom_implementation():
...@@ -709,7 +704,10 @@ def test_random_custom_implementation(): ...@@ -709,7 +704,10 @@ def test_random_custom_implementation():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
out = nonexistentrv(rng=rng) out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
compare_jax_and_py(fgraph, []) with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
compare_jax_and_py(fgraph, [])
def test_random_concrete_shape(): def test_random_concrete_shape():
...@@ -726,7 +724,7 @@ def test_random_concrete_shape(): ...@@ -726,7 +724,7 @@ def test_random_concrete_shape():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_at = at.dmatrix() x_at = at.dmatrix()
out = at.random.normal(0, 1, size=x_at.shape, rng=rng) out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
jax_fn = function([x_at], out, mode=jax_mode) jax_fn = random_function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3) assert jax_fn(np.ones((2, 3))).shape == (2, 3)
...@@ -734,11 +732,7 @@ def test_random_concrete_shape_from_param(): ...@@ -734,11 +732,7 @@ def test_random_concrete_shape_from_param():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_at = at.dmatrix() x_at = at.dmatrix()
out = at.random.normal(x_at, 1, rng=rng) out = at.random.normal(x_at, 1, rng=rng)
with pytest.warns( jax_fn = random_function([x_at], out, mode=jax_mode)
UserWarning,
match="The RandomType SharedVariables \[.+\] will not be used"
):
jax_fn = function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3) assert jax_fn(np.ones((2, 3))).shape == (2, 3)
...@@ -757,7 +751,7 @@ def test_random_concrete_shape_subtensor(): ...@@ -757,7 +751,7 @@ def test_random_concrete_shape_subtensor():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_at = at.dmatrix() x_at = at.dmatrix()
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng) out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
jax_fn = function([x_at], out, mode=jax_mode) jax_fn = random_function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (3,) assert jax_fn(np.ones((2, 3))).shape == (3,)
...@@ -773,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple(): ...@@ -773,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_at = at.dmatrix() x_at = at.dmatrix()
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng) out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
jax_fn = function([x_at], out, mode=jax_mode) jax_fn = random_function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2,) assert jax_fn(np.ones((2, 3))).shape == (2,)
...@@ -784,5 +778,5 @@ def test_random_concrete_shape_graph_input(): ...@@ -784,5 +778,5 @@ def test_random_concrete_shape_graph_input():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
size_at = at.scalar() size_at = at.scalar()
out = at.random.normal(0, 1, size=size_at, rng=rng) out = at.random.normal(0, 1, size=size_at, rng=rng)
jax_fn = function([size_at], out, mode=jax_mode) jax_fn = random_function([size_at], out, mode=jax_mode)
assert jax_fn(10).shape == (10,) assert jax_fn(10).shape == (10,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论