提交 a71c1aca authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use distribution tests on Numba samplers that don't match NumPy

上级 6b062f0a
...@@ -5,6 +5,7 @@ from unittest import mock ...@@ -5,6 +5,7 @@ from unittest import mock
import numba import numba
import numpy as np import numpy as np
import pytest import pytest
import scipy.stats as stats
import aesara.scalar as aes import aesara.scalar as aes
import aesara.scalar.basic as aesb import aesara.scalar.basic as aesb
...@@ -2731,21 +2732,6 @@ def test_shared(): ...@@ -2731,21 +2732,6 @@ def test_shared():
], ],
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
), ),
pytest.param(
aer.beta,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
( (
aer.lognormal, aer.lognormal,
[ [
...@@ -2760,32 +2746,6 @@ def test_shared(): ...@@ -2760,32 +2746,6 @@ def test_shared():
], ],
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
), ),
pytest.param(
aer.gamma,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
pytest.param(
aer.chisquare,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
pytest.param( pytest.param(
aer.pareto, aer.pareto,
[ [
...@@ -2797,21 +2757,6 @@ def test_shared(): ...@@ -2797,21 +2757,6 @@ def test_shared():
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Not implemented"), marks=pytest.mark.xfail(reason="Not implemented"),
), ),
pytest.param(
aer.gumbel,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
( (
aer.exponential, aer.exponential,
[ [
...@@ -2846,21 +2791,6 @@ def test_shared(): ...@@ -2846,21 +2791,6 @@ def test_shared():
], ],
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
), ),
pytest.param(
aer.vonmises,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
( (
aer.geometric, aer.geometric,
[ [
...@@ -2889,21 +2819,6 @@ def test_shared(): ...@@ -2889,21 +2819,6 @@ def test_shared():
], ],
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
), ),
pytest.param(
aer.cauchy,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
( (
aer.wald, aer.wald,
[ [
...@@ -2946,20 +2861,6 @@ def test_shared(): ...@@ -2946,20 +2861,6 @@ def test_shared():
], ],
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
), ),
(
aer.negative_binomial,
[
set_test_value(
at.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
at.dscalar(),
np.array(0.9, dtype=np.float64),
),
],
at.as_tensor([3, 2]),
),
( (
aer.normal, aer.normal,
[ [
...@@ -3040,7 +2941,8 @@ def test_shared(): ...@@ -3040,7 +2941,8 @@ def test_shared():
], ],
ids=str, ids=str,
) )
def test_RandomVariable(rv_op, dist_args, size): def test_aligned_RandomVariable(rv_op, dist_args, size):
"""Tests for Numba samplers that are one-to-one with Aesara's/NumPy's samplers."""
rng = shared(np.random.RandomState(29402)) rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_args, size=size, rng=rng) g = rv_op(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
...@@ -3055,6 +2957,149 @@ def test_RandomVariable(rv_op, dist_args, size): ...@@ -3055,6 +2957,149 @@ def test_RandomVariable(rv_op, dist_args, size):
) )
@pytest.mark.parametrize(
"rv_op, dist_args, base_size, cdf_name, params_conv",
[
(
aer.beta,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"beta",
lambda *args: args,
),
(
aer.gamma,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"gamma",
lambda a, b: (a, 0.0, b),
),
(
aer.cauchy,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"cauchy",
lambda *args: args,
),
(
aer.chisquare,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
"chi2",
lambda *args: args,
),
(
aer.gumbel,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"gumbel_r",
lambda *args: args,
),
(
aer.negative_binomial,
[
set_test_value(
at.lvector(),
np.array([100, 200], dtype=np.int64),
),
set_test_value(
at.dscalar(),
np.array(0.09, dtype=np.float64),
),
],
(2,),
"nbinom",
lambda *args: args,
),
pytest.param(
aer.vonmises,
[
set_test_value(
at.dvector(),
np.array([-0.5, 0.5], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"vonmises_line",
lambda mu, kappa: (kappa, mu),
marks=pytest.mark.xfail(
reason=(
"Numba's parameterization of `vonmises` does not match NumPy's."
"See https://github.com/numba/numba/issues/7886"
)
),
),
],
ids=str,
)
def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv):
"""Tests for Numba samplers that are not one-to-one with Aesara's/NumPy's samplers."""
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_args, size=(2000,) + base_size, rng=rng)
g_fn = function(dist_args, g, mode=numba_mode)
samples = g_fn(
*[
i.tag.test_value
for i in g_fn.maker.fgraph.inputs
if not isinstance(i, (SharedVariable, Constant))
]
)
bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_args])
for idx in np.ndindex(*base_size):
cdf_params = params_conv(*tuple(arg[idx] for arg in bcast_dist_args))
test_res = stats.cramervonmises(
samples[(Ellipsis,) + idx], cdf_name, args=cdf_params
)
assert test_res.pvalue > 0.1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rv_op, dist_args, size, cm", "rv_op, dist_args, size, cm",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论