提交 081967d3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Numba implementation for DirichletRV

上级 87751407
...@@ -343,3 +343,45 @@ def numba_funcify_CategoricalRV(op, node, **kwargs): ...@@ -343,3 +343,45 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
return (rng, res) return (rng, res)
return categorical_rv return categorical_rv
@numba_funcify.register(aer.DirichletRV)
def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
alphas_ndim = node.inputs[3].type.ndim
neg_ind_shape_len = -alphas_ndim + 1
size_len = int(get_vector_length(node.inputs[1]))
if alphas_ndim > 1:
@numba_basic.numba_njit
def dirichlet_rv(rng, size, dtype, alphas):
if size_len > 0:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
if (
0 < alphas.ndim - 1 <= len(size_tpl)
and size_tpl[neg_ind_shape_len:] != alphas.shape[:-1]
):
raise ValueError("Parameters shape and size do not match.")
samples_shape = size_tpl + alphas.shape[-1:]
else:
samples_shape = alphas.shape
res = np.empty(samples_shape, dtype=out_dtype)
alphas_bcast = np.broadcast_to(alphas, samples_shape)
for index in np.ndindex(*samples_shape[:-1]):
res[index] = np.random.dirichlet(alphas_bcast[index])
return (rng, res)
else:
@numba_basic.numba_njit
def dirichlet_rv(rng, size, dtype, alphas):
size = numba_ndarray.to_fixed_tuple(size, size_len)
return (rng, np.random.dirichlet(alphas, size))
return dirichlet_rv
...@@ -3221,6 +3221,62 @@ def test_CategoricalRV(dist_args, size, cm): ...@@ -3221,6 +3221,62 @@ def test_CategoricalRV(dist_args, size, cm):
) )
@pytest.mark.parametrize(
"a, size, cm",
[
pytest.param(
set_test_value(
at.dvector(),
np.array([100000, 1, 1], dtype=np.float64),
),
None,
contextlib.suppress(),
),
pytest.param(
set_test_value(
at.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
(10, 3),
contextlib.suppress(),
),
pytest.param(
set_test_value(
at.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
(10, 4),
pytest.raises(ValueError, match="Parameters shape.*"),
),
],
)
def test_DirichletRV(a, size, cm):
rng = shared(np.random.RandomState(29402))
g = aer.dirichlet(a, size=size, rng=rng)
g_fn = function([a], g, mode=numba_mode)
with cm:
a_val = a.tag.test_value
# For coverage purposes only...
eval_python_only([a], FunctionGraph(outputs=[g], clone=False), [a_val])
all_samples = []
for i in range(1000):
samples = g_fn(a_val)
all_samples.append(samples)
exp_res = a_val / a_val.sum(-1)
res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1)))
assert np.allclose(res, exp_res, atol=1e-4)
def test_RandomState_updates(): def test_RandomState_updates():
rng = shared(np.random.RandomState(1)) rng = shared(np.random.RandomState(1))
rng_new = shared(np.random.RandomState(2)) rng_new = shared(np.random.RandomState(2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论