提交 0c7720cf authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a Numba implementation of CategoricalRV

上级 1a5749d4
...@@ -297,3 +297,39 @@ def numba_funcify_BernoulliRV(op, node, **kwargs): ...@@ -297,3 +297,39 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
body_fn, body_fn,
{"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast}, {"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast},
) )
@numba_funcify.register(aer.CategoricalRV)
def numba_funcify_CategoricalRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
ind_shape_len = node.inputs[3].type.ndim - 1
neg_ind_shape_len = -ind_shape_len
size_len = int(get_vector_length(node.inputs[1]))
@numba_basic.numba_njit
def sampler(rng, size, dtype, p):
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
ind_shape = p.shape[:-1]
if ind_shape_len > 0:
if size_len > 0 and size_tpl[neg_ind_shape_len:] != ind_shape:
raise ValueError("Parameters shape and size do not match.")
samples_shape = size_tpl[:neg_ind_shape_len] + ind_shape
p_bcast = np.broadcast_to(p, size_tpl[:neg_ind_shape_len] + p.shape)
else:
samples_shape = size_tpl
p_bcast = p
unif_samples = np.random.uniform(0, 1, samples_shape)
res = np.empty(samples_shape, dtype=out_dtype)
for idx in np.ndindex(*samples_shape):
res[idx] = np.searchsorted(np.cumsum(p_bcast[idx]), unif_samples[idx])
return (rng, res)
return sampler
...@@ -3056,6 +3056,67 @@ def test_RandomVariable(rv_op, dist_args, size): ...@@ -3056,6 +3056,67 @@ def test_RandomVariable(rv_op, dist_args, size):
) )
@pytest.mark.parametrize(
"rv_op, dist_args, size, cm",
[
pytest.param(
aer.categorical,
[
set_test_value(
at.dvector(),
np.array([100000, 1, 1], dtype=np.float64),
),
],
None,
contextlib.suppress(),
),
pytest.param(
aer.categorical,
[
set_test_value(
at.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
],
(10, 3),
contextlib.suppress(),
),
pytest.param(
aer.categorical,
[
set_test_value(
at.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
],
(10, 4),
pytest.raises(ValueError, match="Parameters shape.*"),
),
],
ids=str,
)
def test_CategoricalRV(rv_op, dist_args, size, cm):
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
def test_RandomState_updates(): def test_RandomState_updates():
rng = shared(np.random.RandomState(1)) rng = shared(np.random.RandomState(1))
rng_new = shared(np.random.RandomState(2)) rng_new = shared(np.random.RandomState(2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论