提交 fc985340 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename rv_numpy_tester to compare_sample_values and update/add docstrings

上级 9a65fcd7
...@@ -63,6 +63,8 @@ py_mode = Mode("py", opts) ...@@ -63,6 +63,8 @@ py_mode = Mode("py", opts)
def fixed_scipy_rvs(rvs_name): def fixed_scipy_rvs(rvs_name):
"""Create a SciPy sampling function compatible with the `test_fn` argument of `compare_sample_values`."""
def _rvs(*args, size=None, **kwargs): def _rvs(*args, size=None, **kwargs):
res = getattr(stats, rvs_name).rvs(*args, size=size, **kwargs) res = getattr(stats, rvs_name).rvs(*args, size=size, **kwargs)
res = np.broadcast_to( res = np.broadcast_to(
...@@ -76,9 +78,12 @@ def fixed_scipy_rvs(rvs_name): ...@@ -76,9 +78,12 @@ def fixed_scipy_rvs(rvs_name):
return _rvs return _rvs
def rv_numpy_tester(rv, *params, rng=None, test_fn=None, **kwargs): def compare_sample_values(rv, *params, rng=None, test_fn=None, **kwargs):
"""Test for correspondence between `RandomVariable` and NumPy shape and """Test for equivalence between `RandomVariable` and NumPy/other samples.
broadcast dimensions.
An equivalently named method on a NumPy RNG object will be used, unless
`test_fn` is specified.
""" """
if rng is None: if rng is None:
rng = np.random.default_rng() rng = np.random.default_rng()
...@@ -137,11 +142,11 @@ def rv_numpy_tester(rv, *params, rng=None, test_fn=None, **kwargs): ...@@ -137,11 +142,11 @@ def rv_numpy_tester(rv, *params, rng=None, test_fn=None, **kwargs):
], ],
) )
def test_uniform_samples(u, l, size): def test_uniform_samples(u, l, size):
rv_numpy_tester(uniform, u, l, size=size) compare_sample_values(uniform, u, l, size=size)
def test_uniform_default_args(): def test_uniform_default_args():
rv_numpy_tester(uniform) compare_sample_values(uniform)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -168,7 +173,7 @@ def test_uniform_default_args(): ...@@ -168,7 +173,7 @@ def test_uniform_default_args():
], ],
) )
def test_triangular_samples(left, mode, right, size): def test_triangular_samples(left, mode, right, size):
rv_numpy_tester(triangular, left, mode, right, size=size) compare_sample_values(triangular, left, mode, right, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -184,7 +189,7 @@ def test_triangular_samples(left, mode, right, size): ...@@ -184,7 +189,7 @@ def test_triangular_samples(left, mode, right, size):
], ],
) )
def test_beta_samples(a, b, size): def test_beta_samples(a, b, size):
rv_numpy_tester(beta, a, b, size=size) compare_sample_values(beta, a, b, size=size)
M_at = iscalar("M") M_at = iscalar("M")
...@@ -286,11 +291,11 @@ def test_normal_ShapeFeature(): ...@@ -286,11 +291,11 @@ def test_normal_ShapeFeature():
], ],
) )
def test_normal_samples(mean, sigma, size): def test_normal_samples(mean, sigma, size):
rv_numpy_tester(normal, mean, sigma, size=size) compare_sample_values(normal, mean, sigma, size=size)
def test_normal_default_args(): def test_normal_default_args():
rv_numpy_tester(standard_normal) compare_sample_values(standard_normal)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -306,7 +311,7 @@ def test_normal_default_args(): ...@@ -306,7 +311,7 @@ def test_normal_default_args():
], ],
) )
def test_halfnormal_samples(mean, sigma, size): def test_halfnormal_samples(mean, sigma, size):
rv_numpy_tester( compare_sample_values(
halfnormal, mean, sigma, size=size, test_fn=fixed_scipy_rvs("halfnorm") halfnormal, mean, sigma, size=size, test_fn=fixed_scipy_rvs("halfnorm")
) )
...@@ -324,7 +329,7 @@ def test_halfnormal_samples(mean, sigma, size): ...@@ -324,7 +329,7 @@ def test_halfnormal_samples(mean, sigma, size):
], ],
) )
def test_lognormal_samples(mean, sigma, size): def test_lognormal_samples(mean, sigma, size):
rv_numpy_tester(lognormal, mean, sigma, size=size) compare_sample_values(lognormal, mean, sigma, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -345,7 +350,7 @@ def test_gamma_samples(a, b, size): ...@@ -345,7 +350,7 @@ def test_gamma_samples(a, b, size):
def test_fn(shape, rate, **kwargs): def test_fn(shape, rate, **kwargs):
return gamma_test_fn(shape, scale=1.0 / rate, **kwargs) return gamma_test_fn(shape, scale=1.0 / rate, **kwargs)
rv_numpy_tester( compare_sample_values(
gamma, gamma,
a, a,
b, b,
...@@ -363,7 +368,7 @@ def test_gamma_samples(a, b, size): ...@@ -363,7 +368,7 @@ def test_gamma_samples(a, b, size):
], ],
) )
def test_chisquare_samples(df, size): def test_chisquare_samples(df, size):
rv_numpy_tester(chisquare, df, size=size, test_fn=fixed_scipy_rvs("chi2")) compare_sample_values(chisquare, df, size=size, test_fn=fixed_scipy_rvs("chi2"))
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -379,7 +384,9 @@ def test_chisquare_samples(df, size): ...@@ -379,7 +384,9 @@ def test_chisquare_samples(df, size):
], ],
) )
def test_gumbel_samples(mu, beta, size): def test_gumbel_samples(mu, beta, size):
rv_numpy_tester(gumbel, mu, beta, size=size, test_fn=fixed_scipy_rvs("gumbel_r")) compare_sample_values(
gumbel, mu, beta, size=size, test_fn=fixed_scipy_rvs("gumbel_r")
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -394,11 +401,11 @@ def test_gumbel_samples(mu, beta, size): ...@@ -394,11 +401,11 @@ def test_gumbel_samples(mu, beta, size):
], ],
) )
def test_exponential_samples(lam, size): def test_exponential_samples(lam, size):
rv_numpy_tester(exponential, lam, size=size) compare_sample_values(exponential, lam, size=size)
def test_exponential_default_args(): def test_exponential_default_args():
rv_numpy_tester(exponential) compare_sample_values(exponential)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -413,7 +420,7 @@ def test_exponential_default_args(): ...@@ -413,7 +420,7 @@ def test_exponential_default_args():
], ],
) )
def test_weibull_samples(alpha, size): def test_weibull_samples(alpha, size):
rv_numpy_tester(weibull, alpha, size=size) compare_sample_values(weibull, alpha, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -429,11 +436,11 @@ def test_weibull_samples(alpha, size): ...@@ -429,11 +436,11 @@ def test_weibull_samples(alpha, size):
], ],
) )
def test_logistic_samples(loc, scale, size): def test_logistic_samples(loc, scale, size):
rv_numpy_tester(logistic, loc, scale, size=size) compare_sample_values(logistic, loc, scale, size=size)
def test_logistic_default_args(): def test_logistic_default_args():
rv_numpy_tester(logistic) compare_sample_values(logistic)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -453,7 +460,7 @@ def test_logistic_default_args(): ...@@ -453,7 +460,7 @@ def test_logistic_default_args():
], ],
) )
def test_vonmises_samples(mu, kappa, size): def test_vonmises_samples(mu, kappa, size):
rv_numpy_tester(vonmises, mu, kappa, size=size) compare_sample_values(vonmises, mu, kappa, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -468,7 +475,7 @@ def test_vonmises_samples(mu, kappa, size): ...@@ -468,7 +475,7 @@ def test_vonmises_samples(mu, kappa, size):
], ],
) )
def test_pareto_samples(alpha, size): def test_pareto_samples(alpha, size):
rv_numpy_tester(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto")) compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto"))
def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
...@@ -561,11 +568,13 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): ...@@ -561,11 +568,13 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
], ],
) )
def test_mvnormal_samples(mu, cov, size): def test_mvnormal_samples(mu, cov, size):
rv_numpy_tester(multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn) compare_sample_values(
multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn
)
def test_mvnormal_default_args(): def test_mvnormal_default_args():
rv_numpy_tester(multivariate_normal, test_fn=mvnormal_test_fn) compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn)
with pytest.raises(ValueError, match="shape mismatch.*"): with pytest.raises(ValueError, match="shape mismatch.*"):
multivariate_normal.rng_fn( multivariate_normal.rng_fn(
...@@ -637,7 +646,7 @@ def test_dirichlet_samples(alphas, size): ...@@ -637,7 +646,7 @@ def test_dirichlet_samples(alphas, size):
size = () size = ()
return dirichlet.rng_fn(random_state, alphas, size) return dirichlet.rng_fn(random_state, alphas, size)
rv_numpy_tester(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn) compare_sample_values(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn)
def test_dirichlet_rng(): def test_dirichlet_rng():
...@@ -723,11 +732,11 @@ def test_dirichlet_ShapeFeature(): ...@@ -723,11 +732,11 @@ def test_dirichlet_ShapeFeature():
], ],
) )
def test_poisson_samples(lam, size): def test_poisson_samples(lam, size):
rv_numpy_tester(poisson, lam, size=size) compare_sample_values(poisson, lam, size=size)
def test_poisson_default_args(): def test_poisson_default_args():
rv_numpy_tester(poisson) compare_sample_values(poisson)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -742,7 +751,7 @@ def test_poisson_default_args(): ...@@ -742,7 +751,7 @@ def test_poisson_default_args():
], ],
) )
def test_geometric_samples(p, size): def test_geometric_samples(p, size):
rv_numpy_tester(geometric, p, size=size) compare_sample_values(geometric, p, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -769,7 +778,7 @@ def test_geometric_samples(p, size): ...@@ -769,7 +778,7 @@ def test_geometric_samples(p, size):
], ],
) )
def test_hypergeometric_samples(ngood, nbad, nsample, size): def test_hypergeometric_samples(ngood, nbad, nsample, size):
rv_numpy_tester(hypergeometric, ngood, nbad, nsample, size=size) compare_sample_values(hypergeometric, ngood, nbad, nsample, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -786,11 +795,13 @@ def test_hypergeometric_samples(ngood, nbad, nsample, size): ...@@ -786,11 +795,13 @@ def test_hypergeometric_samples(ngood, nbad, nsample, size):
], ],
) )
def test_cauchy_samples(loc, scale, size): def test_cauchy_samples(loc, scale, size):
rv_numpy_tester(cauchy, loc, scale, size=size, test_fn=fixed_scipy_rvs("cauchy")) compare_sample_values(
cauchy, loc, scale, size=size, test_fn=fixed_scipy_rvs("cauchy")
)
def test_cauchy_default_args(): def test_cauchy_default_args():
rv_numpy_tester(cauchy, test_fn=stats.cauchy.rvs) compare_sample_values(cauchy, test_fn=stats.cauchy.rvs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -807,13 +818,13 @@ def test_cauchy_default_args(): ...@@ -807,13 +818,13 @@ def test_cauchy_default_args():
], ],
) )
def test_halfcauchy_samples(loc, scale, size): def test_halfcauchy_samples(loc, scale, size):
rv_numpy_tester( compare_sample_values(
halfcauchy, loc, scale, size=size, test_fn=fixed_scipy_rvs("halfcauchy") halfcauchy, loc, scale, size=size, test_fn=fixed_scipy_rvs("halfcauchy")
) )
def test_halfcauchy_default_args(): def test_halfcauchy_default_args():
rv_numpy_tester(halfcauchy, test_fn=stats.halfcauchy.rvs) compare_sample_values(halfcauchy, test_fn=stats.halfcauchy.rvs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -830,7 +841,7 @@ def test_halfcauchy_default_args(): ...@@ -830,7 +841,7 @@ def test_halfcauchy_default_args():
], ],
) )
def test_invgamma_samples(loc, scale, size): def test_invgamma_samples(loc, scale, size):
rv_numpy_tester( compare_sample_values(
invgamma, invgamma,
loc, loc,
scale, scale,
...@@ -855,7 +866,7 @@ def test_invgamma_samples(loc, scale, size): ...@@ -855,7 +866,7 @@ def test_invgamma_samples(loc, scale, size):
], ],
) )
def test_wald_samples(mean, scale, size): def test_wald_samples(mean, scale, size):
rv_numpy_tester(wald, mean, scale, size=size) compare_sample_values(wald, mean, scale, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -888,7 +899,7 @@ def test_wald_samples(mean, scale, size): ...@@ -888,7 +899,7 @@ def test_wald_samples(mean, scale, size):
], ],
) )
def test_truncexpon_samples(b, loc, scale, size): def test_truncexpon_samples(b, loc, scale, size):
rv_numpy_tester( compare_sample_values(
truncexpon, truncexpon,
b, b,
loc, loc,
...@@ -922,7 +933,7 @@ def test_truncexpon_samples(b, loc, scale, size): ...@@ -922,7 +933,7 @@ def test_truncexpon_samples(b, loc, scale, size):
], ],
) )
def test_bernoulli_samples(p, size): def test_bernoulli_samples(p, size):
rv_numpy_tester( compare_sample_values(
bernoulli, bernoulli,
p, p,
size=size, size=size,
...@@ -958,7 +969,7 @@ def test_bernoulli_samples(p, size): ...@@ -958,7 +969,7 @@ def test_bernoulli_samples(p, size):
], ],
) )
def test_laplace_samples(loc, scale, size): def test_laplace_samples(loc, scale, size):
rv_numpy_tester(laplace, loc, scale, size=size) compare_sample_values(laplace, loc, scale, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -987,7 +998,7 @@ def test_laplace_samples(loc, scale, size): ...@@ -987,7 +998,7 @@ def test_laplace_samples(loc, scale, size):
], ],
) )
def test_binomial_samples(M, p, size): def test_binomial_samples(M, p, size):
rv_numpy_tester(binomial, M, p, size=size) compare_sample_values(binomial, M, p, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1016,7 +1027,7 @@ def test_binomial_samples(M, p, size): ...@@ -1016,7 +1027,7 @@ def test_binomial_samples(M, p, size):
], ],
) )
def test_nbinom_samples(M, p, size): def test_nbinom_samples(M, p, size):
rv_numpy_tester( compare_sample_values(
nbinom, nbinom,
M, M,
p, p,
...@@ -1057,7 +1068,7 @@ def test_nbinom_samples(M, p, size): ...@@ -1057,7 +1068,7 @@ def test_nbinom_samples(M, p, size):
], ],
) )
def test_betabinom_samples(M, a, p, size): def test_betabinom_samples(M, a, p, size):
rv_numpy_tester( compare_sample_values(
betabinom, betabinom,
M, M,
a, a,
...@@ -1114,7 +1125,7 @@ def test_betabinom_samples(M, a, p, size): ...@@ -1114,7 +1125,7 @@ def test_betabinom_samples(M, a, p, size):
) )
def test_multinomial_samples(M, p, size, test_fn): def test_multinomial_samples(M, p, size, test_fn):
rng = np.random.default_rng(1234) rng = np.random.default_rng(1234)
rv_numpy_tester( compare_sample_values(
multinomial, multinomial,
M, M,
p, p,
...@@ -1162,7 +1173,7 @@ def test_categorical_samples(p, size, test_fn): ...@@ -1162,7 +1173,7 @@ def test_categorical_samples(p, size, test_fn):
p = p / p.sum(axis=-1) p = p / p.sum(axis=-1)
rng = np.random.default_rng(232) rng = np.random.default_rng(232)
rv_numpy_tester( compare_sample_values(
categorical, categorical,
p, p,
size=size, size=size,
...@@ -1187,14 +1198,14 @@ def test_randint_samples(): ...@@ -1187,14 +1198,14 @@ def test_randint_samples():
randint(10, rng=shared(np.random.default_rng())) randint(10, rng=shared(np.random.default_rng()))
rng = np.random.RandomState(2313) rng = np.random.RandomState(2313)
rv_numpy_tester(randint, 10, None, rng=rng) compare_sample_values(randint, 10, None, rng=rng)
rv_numpy_tester(randint, 0, 1, rng=rng) compare_sample_values(randint, 0, 1, rng=rng)
rv_numpy_tester(randint, 0, 1, size=[3], rng=rng) compare_sample_values(randint, 0, 1, size=[3], rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, rng=rng) compare_sample_values(randint, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, size=[3, 3], rng=rng) compare_sample_values(randint, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(randint, [0], [5], size=[1], rng=rng) compare_sample_values(randint, [0], [5], size=[1], rng=rng)
rv_numpy_tester(randint, at.as_tensor_variable([-1]), [1], size=[1], rng=rng) compare_sample_values(randint, at.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester( compare_sample_values(
randint, randint,
at.as_tensor_variable([-1]), at.as_tensor_variable([-1]),
[1], [1],
...@@ -1209,14 +1220,14 @@ def test_integers_samples(): ...@@ -1209,14 +1220,14 @@ def test_integers_samples():
integers(10, rng=shared(np.random.RandomState())) integers(10, rng=shared(np.random.RandomState()))
rng = np.random.default_rng(2313) rng = np.random.default_rng(2313)
rv_numpy_tester(integers, 10, None, rng=rng) compare_sample_values(integers, 10, None, rng=rng)
rv_numpy_tester(integers, 0, 1, rng=rng) compare_sample_values(integers, 0, 1, rng=rng)
rv_numpy_tester(integers, 0, 1, size=[3], rng=rng) compare_sample_values(integers, 0, 1, size=[3], rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, rng=rng) compare_sample_values(integers, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, size=[3, 3], rng=rng) compare_sample_values(integers, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(integers, [0], [5], size=[1], rng=rng) compare_sample_values(integers, [0], [5], size=[1], rng=rng)
rv_numpy_tester(integers, at.as_tensor_variable([-1]), [1], size=[1], rng=rng) compare_sample_values(integers, at.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester( compare_sample_values(
integers, integers,
at.as_tensor_variable([-1]), at.as_tensor_variable([-1]),
[1], [1],
...@@ -1229,28 +1240,30 @@ def test_choice_samples(): ...@@ -1229,28 +1240,30 @@ def test_choice_samples():
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
choice._supp_shape_from_params(np.asarray(5)) choice._supp_shape_from_params(np.asarray(5))
rv_numpy_tester(choice, np.asarray([5])) compare_sample_values(choice, np.asarray([5]))
rv_numpy_tester(choice, np.array([1.0, 5.0], dtype=config.floatX)) compare_sample_values(choice, np.array([1.0, 5.0], dtype=config.floatX))
rv_numpy_tester(choice, np.asarray([5]), 3) compare_sample_values(choice, np.asarray([5]), 3)
with pytest.raises(ValueError): with pytest.raises(ValueError):
rv_numpy_tester(choice, np.array([[1, 2], [3, 4]])) compare_sample_values(choice, np.array([[1, 2], [3, 4]]))
rv_numpy_tester(choice, [1, 2, 3], 1) compare_sample_values(choice, [1, 2, 3], 1)
rv_numpy_tester(choice, [1, 2, 3], 1, p=at.as_tensor([1 / 3.0, 1 / 3.0, 1 / 3.0])) compare_sample_values(
rv_numpy_tester(choice, [1, 2, 3], (10, 2), replace=True) choice, [1, 2, 3], 1, p=at.as_tensor([1 / 3.0, 1 / 3.0, 1 / 3.0])
rv_numpy_tester(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True) )
compare_sample_values(choice, [1, 2, 3], (10, 2), replace=True)
compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)
def test_permutation_samples(): def test_permutation_samples():
rv_numpy_tester( compare_sample_values(
permutation, permutation,
np.asarray(5), np.asarray(5),
test_fn=lambda x, random_state=None: random_state.permutation(x.item()), test_fn=lambda x, random_state=None: random_state.permutation(x.item()),
) )
rv_numpy_tester(permutation, [1, 2, 3]) compare_sample_values(permutation, [1, 2, 3])
rv_numpy_tester(permutation, [[1, 2], [3, 4]]) compare_sample_values(permutation, [[1, 2], [3, 4]])
rv_numpy_tester(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX)) compare_sample_values(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX))
@config.change_flags(compute_test_value="off") @config.change_flags(compute_test_value="off")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论