提交 87751407 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename function and remove unnecessary param for Numba CategoricalRV

上级 27423150
...@@ -319,7 +319,7 @@ def numba_funcify_CategoricalRV(op, node, **kwargs): ...@@ -319,7 +319,7 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
size_len = int(get_vector_length(node.inputs[1])) size_len = int(get_vector_length(node.inputs[1]))
@numba_basic.numba_njit @numba_basic.numba_njit
def sampler(rng, size, dtype, p): def categorical_rv(rng, size, dtype, p):
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
ind_shape = p.shape[:-1] ind_shape = p.shape[:-1]
...@@ -342,4 +342,4 @@ def numba_funcify_CategoricalRV(op, node, **kwargs): ...@@ -342,4 +342,4 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
return (rng, res) return (rng, res)
return sampler return categorical_rv
...@@ -3165,10 +3165,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -3165,10 +3165,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rv_op, dist_args, size, cm", "dist_args, size, cm",
[ [
pytest.param( pytest.param(
aer.categorical,
[ [
set_test_value( set_test_value(
at.dvector(), at.dvector(),
...@@ -3179,7 +3178,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -3179,7 +3178,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(), contextlib.suppress(),
), ),
pytest.param( pytest.param(
aer.categorical,
[ [
set_test_value( set_test_value(
at.dmatrix(), at.dmatrix(),
...@@ -3193,7 +3191,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -3193,7 +3191,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(), contextlib.suppress(),
), ),
pytest.param( pytest.param(
aer.categorical,
[ [
set_test_value( set_test_value(
at.dmatrix(), at.dmatrix(),
...@@ -3208,9 +3205,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -3208,9 +3205,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
), ),
], ],
) )
def test_CategoricalRV(rv_op, dist_args, size, cm): def test_CategoricalRV(dist_args, size, cm):
rng = shared(np.random.RandomState(29402)) rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_args, size=size, rng=rng) g = aer.categorical(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
with cm: with cm:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论