提交 fc985340 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename rv_numpy_tester to compare_sample_values and update/add docstrings

上级 9a65fcd7
......@@ -63,6 +63,8 @@ py_mode = Mode("py", opts)
def fixed_scipy_rvs(rvs_name):
"""Create a SciPy sampling function compatible with the `test_fn` argument of `compare_sample_values`."""
def _rvs(*args, size=None, **kwargs):
res = getattr(stats, rvs_name).rvs(*args, size=size, **kwargs)
res = np.broadcast_to(
......@@ -76,9 +78,12 @@ def fixed_scipy_rvs(rvs_name):
return _rvs
def rv_numpy_tester(rv, *params, rng=None, test_fn=None, **kwargs):
"""Test for correspondence between `RandomVariable` and NumPy shape and
broadcast dimensions.
def compare_sample_values(rv, *params, rng=None, test_fn=None, **kwargs):
"""Test for equivalence between `RandomVariable` and NumPy/other samples.
An equivalently named method on a NumPy RNG object will be used, unless
`test_fn` is specified.
"""
if rng is None:
rng = np.random.default_rng()
......@@ -137,11 +142,11 @@ def rv_numpy_tester(rv, *params, rng=None, test_fn=None, **kwargs):
],
)
def test_uniform_samples(u, l, size):
rv_numpy_tester(uniform, u, l, size=size)
compare_sample_values(uniform, u, l, size=size)
def test_uniform_default_args():
rv_numpy_tester(uniform)
compare_sample_values(uniform)
@pytest.mark.parametrize(
......@@ -168,7 +173,7 @@ def test_uniform_default_args():
],
)
def test_triangular_samples(left, mode, right, size):
rv_numpy_tester(triangular, left, mode, right, size=size)
compare_sample_values(triangular, left, mode, right, size=size)
@pytest.mark.parametrize(
......@@ -184,7 +189,7 @@ def test_triangular_samples(left, mode, right, size):
],
)
def test_beta_samples(a, b, size):
rv_numpy_tester(beta, a, b, size=size)
compare_sample_values(beta, a, b, size=size)
M_at = iscalar("M")
......@@ -286,11 +291,11 @@ def test_normal_ShapeFeature():
],
)
def test_normal_samples(mean, sigma, size):
rv_numpy_tester(normal, mean, sigma, size=size)
compare_sample_values(normal, mean, sigma, size=size)
def test_normal_default_args():
rv_numpy_tester(standard_normal)
compare_sample_values(standard_normal)
@pytest.mark.parametrize(
......@@ -306,7 +311,7 @@ def test_normal_default_args():
],
)
def test_halfnormal_samples(mean, sigma, size):
rv_numpy_tester(
compare_sample_values(
halfnormal, mean, sigma, size=size, test_fn=fixed_scipy_rvs("halfnorm")
)
......@@ -324,7 +329,7 @@ def test_halfnormal_samples(mean, sigma, size):
],
)
def test_lognormal_samples(mean, sigma, size):
rv_numpy_tester(lognormal, mean, sigma, size=size)
compare_sample_values(lognormal, mean, sigma, size=size)
@pytest.mark.parametrize(
......@@ -345,7 +350,7 @@ def test_gamma_samples(a, b, size):
def test_fn(shape, rate, **kwargs):
return gamma_test_fn(shape, scale=1.0 / rate, **kwargs)
rv_numpy_tester(
compare_sample_values(
gamma,
a,
b,
......@@ -363,7 +368,7 @@ def test_gamma_samples(a, b, size):
],
)
def test_chisquare_samples(df, size):
rv_numpy_tester(chisquare, df, size=size, test_fn=fixed_scipy_rvs("chi2"))
compare_sample_values(chisquare, df, size=size, test_fn=fixed_scipy_rvs("chi2"))
@pytest.mark.parametrize(
......@@ -379,7 +384,9 @@ def test_chisquare_samples(df, size):
],
)
def test_gumbel_samples(mu, beta, size):
rv_numpy_tester(gumbel, mu, beta, size=size, test_fn=fixed_scipy_rvs("gumbel_r"))
compare_sample_values(
gumbel, mu, beta, size=size, test_fn=fixed_scipy_rvs("gumbel_r")
)
@pytest.mark.parametrize(
......@@ -394,11 +401,11 @@ def test_gumbel_samples(mu, beta, size):
],
)
def test_exponential_samples(lam, size):
rv_numpy_tester(exponential, lam, size=size)
compare_sample_values(exponential, lam, size=size)
def test_exponential_default_args():
rv_numpy_tester(exponential)
compare_sample_values(exponential)
@pytest.mark.parametrize(
......@@ -413,7 +420,7 @@ def test_exponential_default_args():
],
)
def test_weibull_samples(alpha, size):
rv_numpy_tester(weibull, alpha, size=size)
compare_sample_values(weibull, alpha, size=size)
@pytest.mark.parametrize(
......@@ -429,11 +436,11 @@ def test_weibull_samples(alpha, size):
],
)
def test_logistic_samples(loc, scale, size):
rv_numpy_tester(logistic, loc, scale, size=size)
compare_sample_values(logistic, loc, scale, size=size)
def test_logistic_default_args():
rv_numpy_tester(logistic)
compare_sample_values(logistic)
@pytest.mark.parametrize(
......@@ -453,7 +460,7 @@ def test_logistic_default_args():
],
)
def test_vonmises_samples(mu, kappa, size):
rv_numpy_tester(vonmises, mu, kappa, size=size)
compare_sample_values(vonmises, mu, kappa, size=size)
@pytest.mark.parametrize(
......@@ -468,7 +475,7 @@ def test_vonmises_samples(mu, kappa, size):
],
)
def test_pareto_samples(alpha, size):
rv_numpy_tester(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto"))
compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto"))
def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
......@@ -561,11 +568,13 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
],
)
def test_mvnormal_samples(mu, cov, size):
rv_numpy_tester(multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn)
compare_sample_values(
multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn
)
def test_mvnormal_default_args():
rv_numpy_tester(multivariate_normal, test_fn=mvnormal_test_fn)
compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn)
with pytest.raises(ValueError, match="shape mismatch.*"):
multivariate_normal.rng_fn(
......@@ -637,7 +646,7 @@ def test_dirichlet_samples(alphas, size):
size = ()
return dirichlet.rng_fn(random_state, alphas, size)
rv_numpy_tester(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn)
compare_sample_values(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn)
def test_dirichlet_rng():
......@@ -723,11 +732,11 @@ def test_dirichlet_ShapeFeature():
],
)
def test_poisson_samples(lam, size):
rv_numpy_tester(poisson, lam, size=size)
compare_sample_values(poisson, lam, size=size)
def test_poisson_default_args():
rv_numpy_tester(poisson)
compare_sample_values(poisson)
@pytest.mark.parametrize(
......@@ -742,7 +751,7 @@ def test_poisson_default_args():
],
)
def test_geometric_samples(p, size):
rv_numpy_tester(geometric, p, size=size)
compare_sample_values(geometric, p, size=size)
@pytest.mark.parametrize(
......@@ -769,7 +778,7 @@ def test_geometric_samples(p, size):
],
)
def test_hypergeometric_samples(ngood, nbad, nsample, size):
rv_numpy_tester(hypergeometric, ngood, nbad, nsample, size=size)
compare_sample_values(hypergeometric, ngood, nbad, nsample, size=size)
@pytest.mark.parametrize(
......@@ -786,11 +795,13 @@ def test_hypergeometric_samples(ngood, nbad, nsample, size):
],
)
def test_cauchy_samples(loc, scale, size):
rv_numpy_tester(cauchy, loc, scale, size=size, test_fn=fixed_scipy_rvs("cauchy"))
compare_sample_values(
cauchy, loc, scale, size=size, test_fn=fixed_scipy_rvs("cauchy")
)
def test_cauchy_default_args():
rv_numpy_tester(cauchy, test_fn=stats.cauchy.rvs)
compare_sample_values(cauchy, test_fn=stats.cauchy.rvs)
@pytest.mark.parametrize(
......@@ -807,13 +818,13 @@ def test_cauchy_default_args():
],
)
def test_halfcauchy_samples(loc, scale, size):
rv_numpy_tester(
compare_sample_values(
halfcauchy, loc, scale, size=size, test_fn=fixed_scipy_rvs("halfcauchy")
)
def test_halfcauchy_default_args():
rv_numpy_tester(halfcauchy, test_fn=stats.halfcauchy.rvs)
compare_sample_values(halfcauchy, test_fn=stats.halfcauchy.rvs)
@pytest.mark.parametrize(
......@@ -830,7 +841,7 @@ def test_halfcauchy_default_args():
],
)
def test_invgamma_samples(loc, scale, size):
rv_numpy_tester(
compare_sample_values(
invgamma,
loc,
scale,
......@@ -855,7 +866,7 @@ def test_invgamma_samples(loc, scale, size):
],
)
def test_wald_samples(mean, scale, size):
rv_numpy_tester(wald, mean, scale, size=size)
compare_sample_values(wald, mean, scale, size=size)
@pytest.mark.parametrize(
......@@ -888,7 +899,7 @@ def test_wald_samples(mean, scale, size):
],
)
def test_truncexpon_samples(b, loc, scale, size):
rv_numpy_tester(
compare_sample_values(
truncexpon,
b,
loc,
......@@ -922,7 +933,7 @@ def test_truncexpon_samples(b, loc, scale, size):
],
)
def test_bernoulli_samples(p, size):
rv_numpy_tester(
compare_sample_values(
bernoulli,
p,
size=size,
......@@ -958,7 +969,7 @@ def test_bernoulli_samples(p, size):
],
)
def test_laplace_samples(loc, scale, size):
rv_numpy_tester(laplace, loc, scale, size=size)
compare_sample_values(laplace, loc, scale, size=size)
@pytest.mark.parametrize(
......@@ -987,7 +998,7 @@ def test_laplace_samples(loc, scale, size):
],
)
def test_binomial_samples(M, p, size):
rv_numpy_tester(binomial, M, p, size=size)
compare_sample_values(binomial, M, p, size=size)
@pytest.mark.parametrize(
......@@ -1016,7 +1027,7 @@ def test_binomial_samples(M, p, size):
],
)
def test_nbinom_samples(M, p, size):
rv_numpy_tester(
compare_sample_values(
nbinom,
M,
p,
......@@ -1057,7 +1068,7 @@ def test_nbinom_samples(M, p, size):
],
)
def test_betabinom_samples(M, a, p, size):
rv_numpy_tester(
compare_sample_values(
betabinom,
M,
a,
......@@ -1114,7 +1125,7 @@ def test_betabinom_samples(M, a, p, size):
)
def test_multinomial_samples(M, p, size, test_fn):
rng = np.random.default_rng(1234)
rv_numpy_tester(
compare_sample_values(
multinomial,
M,
p,
......@@ -1162,7 +1173,7 @@ def test_categorical_samples(p, size, test_fn):
p = p / p.sum(axis=-1)
rng = np.random.default_rng(232)
rv_numpy_tester(
compare_sample_values(
categorical,
p,
size=size,
......@@ -1187,14 +1198,14 @@ def test_randint_samples():
randint(10, rng=shared(np.random.default_rng()))
rng = np.random.RandomState(2313)
rv_numpy_tester(randint, 10, None, rng=rng)
rv_numpy_tester(randint, 0, 1, rng=rng)
rv_numpy_tester(randint, 0, 1, size=[3], rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(randint, [0], [5], size=[1], rng=rng)
rv_numpy_tester(randint, at.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester(
compare_sample_values(randint, 10, None, rng=rng)
compare_sample_values(randint, 0, 1, rng=rng)
compare_sample_values(randint, 0, 1, size=[3], rng=rng)
compare_sample_values(randint, [0, 1, 2], 5, rng=rng)
compare_sample_values(randint, [0, 1, 2], 5, size=[3, 3], rng=rng)
compare_sample_values(randint, [0], [5], size=[1], rng=rng)
compare_sample_values(randint, at.as_tensor_variable([-1]), [1], size=[1], rng=rng)
compare_sample_values(
randint,
at.as_tensor_variable([-1]),
[1],
......@@ -1209,14 +1220,14 @@ def test_integers_samples():
integers(10, rng=shared(np.random.RandomState()))
rng = np.random.default_rng(2313)
rv_numpy_tester(integers, 10, None, rng=rng)
rv_numpy_tester(integers, 0, 1, rng=rng)
rv_numpy_tester(integers, 0, 1, size=[3], rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(integers, [0], [5], size=[1], rng=rng)
rv_numpy_tester(integers, at.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester(
compare_sample_values(integers, 10, None, rng=rng)
compare_sample_values(integers, 0, 1, rng=rng)
compare_sample_values(integers, 0, 1, size=[3], rng=rng)
compare_sample_values(integers, [0, 1, 2], 5, rng=rng)
compare_sample_values(integers, [0, 1, 2], 5, size=[3, 3], rng=rng)
compare_sample_values(integers, [0], [5], size=[1], rng=rng)
compare_sample_values(integers, at.as_tensor_variable([-1]), [1], size=[1], rng=rng)
compare_sample_values(
integers,
at.as_tensor_variable([-1]),
[1],
......@@ -1229,28 +1240,30 @@ def test_choice_samples():
with pytest.raises(NotImplementedError):
choice._supp_shape_from_params(np.asarray(5))
rv_numpy_tester(choice, np.asarray([5]))
rv_numpy_tester(choice, np.array([1.0, 5.0], dtype=config.floatX))
rv_numpy_tester(choice, np.asarray([5]), 3)
compare_sample_values(choice, np.asarray([5]))
compare_sample_values(choice, np.array([1.0, 5.0], dtype=config.floatX))
compare_sample_values(choice, np.asarray([5]), 3)
with pytest.raises(ValueError):
rv_numpy_tester(choice, np.array([[1, 2], [3, 4]]))
compare_sample_values(choice, np.array([[1, 2], [3, 4]]))
rv_numpy_tester(choice, [1, 2, 3], 1)
rv_numpy_tester(choice, [1, 2, 3], 1, p=at.as_tensor([1 / 3.0, 1 / 3.0, 1 / 3.0]))
rv_numpy_tester(choice, [1, 2, 3], (10, 2), replace=True)
rv_numpy_tester(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)
compare_sample_values(choice, [1, 2, 3], 1)
compare_sample_values(
choice, [1, 2, 3], 1, p=at.as_tensor([1 / 3.0, 1 / 3.0, 1 / 3.0])
)
compare_sample_values(choice, [1, 2, 3], (10, 2), replace=True)
compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)
def test_permutation_samples():
rv_numpy_tester(
compare_sample_values(
permutation,
np.asarray(5),
test_fn=lambda x, random_state=None: random_state.permutation(x.item()),
)
rv_numpy_tester(permutation, [1, 2, 3])
rv_numpy_tester(permutation, [[1, 2], [3, 4]])
rv_numpy_tester(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX))
compare_sample_values(permutation, [1, 2, 3])
compare_sample_values(permutation, [[1, 2], [3, 4]])
compare_sample_values(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX))
@config.change_flags(compute_test_value="off")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论