提交 87751407 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename function and remove unnecessary param for Numba CategoricalRV

上级 27423150
......@@ -319,7 +319,7 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
size_len = int(get_vector_length(node.inputs[1]))
@numba_basic.numba_njit
def sampler(rng, size, dtype, p):
def categorical_rv(rng, size, dtype, p):
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
ind_shape = p.shape[:-1]
......@@ -342,4 +342,4 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
return (rng, res)
return sampler
return categorical_rv
......@@ -3165,10 +3165,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
@pytest.mark.parametrize(
"rv_op, dist_args, size, cm",
"dist_args, size, cm",
[
pytest.param(
aer.categorical,
[
set_test_value(
at.dvector(),
......@@ -3179,7 +3178,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(),
),
pytest.param(
aer.categorical,
[
set_test_value(
at.dmatrix(),
......@@ -3193,7 +3191,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(),
),
pytest.param(
aer.categorical,
[
set_test_value(
at.dmatrix(),
......@@ -3208,9 +3205,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
),
],
)
def test_CategoricalRV(rv_op, dist_args, size, cm):
def test_CategoricalRV(dist_args, size, cm):
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_args, size=size, rng=rng)
g = aer.categorical(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
with cm:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论