提交 75789deb authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement InvGamma and Multinomial in Numba

上级 084bfd75
......@@ -64,7 +64,6 @@ def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
@numba_core_rv_funcify.register(ptr.LaplaceRV)
@numba_core_rv_funcify.register(ptr.BinomialRV)
@numba_core_rv_funcify.register(ptr.NegBinomialRV)
@numba_core_rv_funcify.register(ptr.MultinomialRV)
@numba_core_rv_funcify.register(ptr.PermutationRV)
@numba_core_rv_funcify.register(ptr.IntegersRV)
def numba_core_rv_default(op, node):
......@@ -132,6 +131,15 @@ def numba_core_ParetoRV(op, node):
return random
@numba_core_rv_funcify.register(ptr.InvGammaRV)
def numba_core_InvGammaRV(op, node):
@numba_basic.numba_njit
def random(rng, shape, scale):
return 1 / rng.gamma(shape, 1 / scale)
return random
@numba_core_rv_funcify.register(ptr.CategoricalRV)
def core_CategoricalRV(op, node):
@numba_basic.numba_njit
......@@ -142,6 +150,29 @@ def core_CategoricalRV(op, node):
return random_fn
@numba_core_rv_funcify.register(ptr.MultinomialRV)
def core_MultinomialRV(op, node):
dtype = op.dtype
@numba_basic.numba_njit
def random_fn(rng, n, p):
n_cat = p.shape[0]
draws = np.zeros(n_cat, dtype=dtype)
remaining_p = np.float64(1.0)
remaining_n = n
for i in range(n_cat - 1):
draws[i] = rng.binomial(remaining_n, p[i] / remaining_p)
remaining_n -= draws[i]
if remaining_n <= 0:
break
remaining_p -= p[i]
if remaining_n > 0:
draws[n_cat - 1] = remaining_n
return draws
return random_fn
@numba_core_rv_funcify.register(ptr.MvNormalRV)
def core_MvNormalRV(op, node):
method = op.method
......
......@@ -514,6 +514,31 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
],
(pt.as_tensor([2, 1])),
),
(
ptr.invgamma,
[
(
pt.dvector("shape"),
np.array([1.0, 2.0], dtype=np.float64),
),
(
pt.dvector("scale"),
np.array([0.5, 3.0], dtype=np.float64),
),
],
(2,),
),
(
ptr.multinomial,
[
(
pt.lvector("n"),
np.array([1, 10, 1000], dtype=np.int64),
),
(pt.dvector("p"), np.array([0.3, 0.7], dtype=np.float64)),
],
None,
),
],
ids=str,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论