提交 6b062f0a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix the Numba implementation of NegBinomialRV

`aesara.random.basic.negative_binomial` has also been added as an alias to `aesara.random.basic.nbinom`.
上级 5ea74e97
......@@ -207,7 +207,6 @@ def {sized_fn_name}({random_fn_input_names}):
@numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported
......@@ -270,6 +269,14 @@ def {np_random_fn_name}({np_input_names}):
return make_numba_random_fn(node, np_random_fn)
@numba_funcify.register(aer.NegBinomialRV)
def numba_funcify_NegBinomialRV(op, node, **kwargs):
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np.random.negative_binomial)
@numba_funcify.register(aer.CauchyRV)
def numba_funcify_CauchyRV(op, node, **kwargs):
def body_fn(loc, scale):
......
......@@ -559,6 +559,7 @@ class NegBinomialRV(ScipyRandomVariable):
nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
class BetaBinomialRV(ScipyRandomVariable):
......@@ -803,4 +804,5 @@ __all__ = [
"triangular",
"uniform",
"standard_normal",
"negative_binomial",
]
......@@ -2946,21 +2946,20 @@ def test_shared():
],
at.as_tensor([3, 2]),
),
# pytest.param(
# aer.negative_binomial,
# [
# set_test_value(
# at.lvector(),
# np.array([1, 2], dtype=np.int64),
# ),
# set_test_value(
# at.dscalar(),
# np.array(0.9, dtype=np.float64),
# ),
# ],
# at.as_tensor([3, 2]),
# marks=pytest.mark.xfail(reason="Not implemented"),
# ),
(
aer.negative_binomial,
[
set_test_value(
at.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
at.dscalar(),
np.array(0.9, dtype=np.float64),
),
],
at.as_tensor([3, 2]),
),
(
aer.normal,
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论