提交 3170c7d8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename helper function

上级 c5b96d92
...@@ -20,7 +20,7 @@ jax = pytest.importorskip("jax") ...@@ -20,7 +20,7 @@ jax = pytest.importorskip("jax")
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402 from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
def random_function(*args, **kwargs): def compile_random_function(*args, **kwargs):
with pytest.warns( with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
): ):
...@@ -35,7 +35,7 @@ def test_random_RandomStream(): ...@@ -35,7 +35,7 @@ def test_random_RandomStream():
srng = RandomStream(seed=123) srng = RandomStream(seed=123)
out = srng.normal() - srng.normal() out = srng.normal() - srng.normal()
fn = random_function([], out, mode=jax_mode) fn = compile_random_function([], out, mode=jax_mode)
jax_res_1 = fn() jax_res_1 = fn()
jax_res_2 = fn() jax_res_2 = fn()
...@@ -48,7 +48,7 @@ def test_random_updates(rng_ctor): ...@@ -48,7 +48,7 @@ def test_random_updates(rng_ctor):
rng = shared(original_value, name="original_rng", borrow=False) rng = shared(original_value, name="original_rng", borrow=False)
next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode) f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
assert f() != f() assert f() != f()
# Check that original rng variable content was not overwritten when calling jax_typify # Check that original rng variable content was not overwritten when calling jax_typify
...@@ -79,7 +79,7 @@ def test_random_updates_input_storage_order(): ...@@ -79,7 +79,7 @@ def test_random_updates_input_storage_order():
# This function replaces inp by input_shared in the update expression # This function replaces inp by input_shared in the update expression
# This is what caused the RNG to appear later than inp_shared in the input_storage # This is what caused the RNG to appear later than inp_shared in the input_storage
fn = random_function( fn = compile_random_function(
inputs=[], inputs=[],
outputs=[], outputs=[],
updates={inp_shared: inp_update}, updates={inp_shared: inp_update},
...@@ -453,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -453,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
else: else:
rng = shared(np.random.RandomState(29402)) rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng) g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
g_fn = random_function(dist_params, g, mode=jax_mode) g_fn = compile_random_function(dist_params, g, mode=jax_mode)
samples = g_fn( samples = g_fn(
*[ *[
i.tag.test_value i.tag.test_value
...@@ -477,7 +477,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -477,7 +477,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
def test_random_bernoulli(size): def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = pt.random.bernoulli(0.5, size=(1000,) + size, rng=rng) g = pt.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
...@@ -488,7 +488,7 @@ def test_random_mvnormal(): ...@@ -488,7 +488,7 @@ def test_random_mvnormal():
mu = np.ones(4) mu = np.ones(4)
cov = np.eye(4) cov = np.eye(4)
g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng) g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1) np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
...@@ -503,7 +503,7 @@ def test_random_mvnormal(): ...@@ -503,7 +503,7 @@ def test_random_mvnormal():
def test_random_dirichlet(parameter, size): def test_random_dirichlet(parameter, size):
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = pt.random.dirichlet(parameter, size=(1000,) + size, rng=rng) g = pt.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
...@@ -513,21 +513,21 @@ def test_random_choice(): ...@@ -513,21 +513,21 @@ def test_random_choice():
num_samples = 10000 num_samples = 10000
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(4), size=num_samples, rng=rng) g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2) np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
# `replace=False` produces unique results # `replace=False` produces unique results
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng) g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
assert len(np.unique(samples)) == 99 assert len(np.unique(samples)) == 99
# We can pass an array with probabilities # We can pass an array with probabilities
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng) g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples, np.zeros(10)) np.testing.assert_allclose(samples, np.zeros(10))
...@@ -535,7 +535,7 @@ def test_random_choice(): ...@@ -535,7 +535,7 @@ def test_random_choice():
def test_random_categorical(): def test_random_categorical():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng) g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1) np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
...@@ -544,7 +544,7 @@ def test_random_permutation(): ...@@ -544,7 +544,7 @@ def test_random_permutation():
array = np.arange(4) array = np.arange(4)
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
g = pt.random.permutation(array, rng=rng) g = pt.random.permutation(array, rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
permuted = g_fn() permuted = g_fn()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
np.testing.assert_allclose(array, permuted) np.testing.assert_allclose(array, permuted)
...@@ -554,7 +554,7 @@ def test_random_geometric(): ...@@ -554,7 +554,7 @@ def test_random_geometric():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = pt.random.geometric(p, size=(10_000, 2), rng=rng) g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1) np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
...@@ -565,7 +565,7 @@ def test_negative_binomial(): ...@@ -565,7 +565,7 @@ def test_negative_binomial():
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng) g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -579,7 +579,7 @@ def test_binomial(): ...@@ -579,7 +579,7 @@ def test_binomial():
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng) g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1) np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
...@@ -594,7 +594,7 @@ def test_beta_binomial(): ...@@ -594,7 +594,7 @@ def test_beta_binomial():
a = np.array([1.5, 13]) a = np.array([1.5, 13])
b = np.array([0.5, 9]) b = np.array([0.5, 9])
g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng) g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -612,7 +612,7 @@ def test_multinomial(): ...@@ -612,7 +612,7 @@ def test_multinomial():
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -628,7 +628,7 @@ def test_vonmises_mu_outside_circle(): ...@@ -628,7 +628,7 @@ def test_vonmises_mu_outside_circle():
mu = np.array([-30, 40]) mu = np.array([-30, 40])
kappa = np.array([100, 10]) kappa = np.array([100, 10])
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng) g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
g_fn = random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
np.testing.assert_allclose( np.testing.assert_allclose(
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1 samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
...@@ -728,7 +728,7 @@ def test_random_concrete_shape(): ...@@ -728,7 +728,7 @@ def test_random_concrete_shape():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng) out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
jax_fn = random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3) assert jax_fn(np.ones((2, 3))).shape == (2, 3)
...@@ -736,7 +736,7 @@ def test_random_concrete_shape_from_param(): ...@@ -736,7 +736,7 @@ def test_random_concrete_shape_from_param():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng) out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3) assert jax_fn(np.ones((2, 3))).shape == (2, 3)
...@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor(): ...@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng) out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
jax_fn = random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (3,) assert jax_fn(np.ones((2, 3))).shape == (3,)
...@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple(): ...@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng) out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
jax_fn = random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2,) assert jax_fn(np.ones((2, 3))).shape == (2,)
...@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input(): ...@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.RandomState(123))
size_pt = pt.scalar() size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng) out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = random_function([size_pt], out, mode=jax_mode) jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
assert jax_fn(10).shape == (10,) assert jax_fn(10).shape == (10,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论