提交 6b062f0a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix the Numba implementation of NegBinomialRV

`aesara.random.basic.negative_binomial` has also been added as an alias to `aesara.random.basic.nbinom`.
上级 5ea74e97
...@@ -207,7 +207,6 @@ def {sized_fn_name}({random_fn_input_names}): ...@@ -207,7 +207,6 @@ def {sized_fn_name}({random_fn_input_names}):
@numba_funcify.register(aer.WaldRV) @numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV) @numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV) @numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV) @numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported @numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported @numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported
...@@ -270,6 +269,14 @@ def {np_random_fn_name}({np_input_names}): ...@@ -270,6 +269,14 @@ def {np_random_fn_name}({np_input_names}):
return make_numba_random_fn(node, np_random_fn) return make_numba_random_fn(node, np_random_fn)
@numba_funcify.register(aer.NegBinomialRV)
def numba_funcify_NegBinomialRV(op, node, **kwargs):
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np.random.negative_binomial)
@numba_funcify.register(aer.CauchyRV) @numba_funcify.register(aer.CauchyRV)
def numba_funcify_CauchyRV(op, node, **kwargs): def numba_funcify_CauchyRV(op, node, **kwargs):
def body_fn(loc, scale): def body_fn(loc, scale):
......
...@@ -559,6 +559,7 @@ class NegBinomialRV(ScipyRandomVariable): ...@@ -559,6 +559,7 @@ class NegBinomialRV(ScipyRandomVariable):
nbinom = NegBinomialRV() nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
class BetaBinomialRV(ScipyRandomVariable): class BetaBinomialRV(ScipyRandomVariable):
...@@ -803,4 +804,5 @@ __all__ = [ ...@@ -803,4 +804,5 @@ __all__ = [
"triangular", "triangular",
"uniform", "uniform",
"standard_normal", "standard_normal",
"negative_binomial",
] ]
...@@ -2946,21 +2946,20 @@ def test_shared(): ...@@ -2946,21 +2946,20 @@ def test_shared():
], ],
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
), ),
# pytest.param( (
# aer.negative_binomial, aer.negative_binomial,
# [ [
# set_test_value( set_test_value(
# at.lvector(), at.lvector(),
# np.array([1, 2], dtype=np.int64), np.array([1, 2], dtype=np.int64),
# ), ),
# set_test_value( set_test_value(
# at.dscalar(), at.dscalar(),
# np.array(0.9, dtype=np.float64), np.array(0.9, dtype=np.float64),
# ), ),
# ], ],
# at.as_tensor([3, 2]), at.as_tensor([3, 2]),
# marks=pytest.mark.xfail(reason="Not implemented"), ),
# ),
( (
aer.normal, aer.normal,
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论