提交 9e0434bf authored 作者: Kyle Caron's avatar Kyle Caron 提交者: Brandon T. Willard

Implemented Generalized Gamma RV

上级 8794f48d
...@@ -569,6 +569,26 @@ class BetaBinomialRV(ScipyRandomVariable): ...@@ -569,6 +569,26 @@ class BetaBinomialRV(ScipyRandomVariable):
betabinom = BetaBinomialRV() betabinom = BetaBinomialRV()
class GenGammaRV(ScipyRandomVariable):
name = "gengamma"
ndim_supp = 0
ndims_params = [0, 0, 0]
dtype = "floatX"
_print_name = ("GG", "\\operatorname{GG}")
def __call__(self, alpha=1.0, p=1.0, lambd=1.0, size=None, **kwargs):
return super().__call__(alpha, p, lambd, size=size, **kwargs)
@classmethod
def rng_fn_scipy(cls, rng, alpha, p, lambd, size):
return stats.gengamma.rvs(
alpha / p, p, scale=lambd, size=size, random_state=rng
)
gengamma = GenGammaRV()
class MultinomialRV(RandomVariable): class MultinomialRV(RandomVariable):
"""A Multinomial random variable type. """A Multinomial random variable type.
...@@ -794,4 +814,5 @@ __all__ = [ ...@@ -794,4 +814,5 @@ __all__ = [
"uniform", "uniform",
"standard_normal", "standard_normal",
"negative_binomial", "negative_binomial",
"gengamma",
] ]
...@@ -28,6 +28,7 @@ from aesara.tensor.random.basic import ( ...@@ -28,6 +28,7 @@ from aesara.tensor.random.basic import (
dirichlet, dirichlet,
exponential, exponential,
gamma, gamma,
gengamma,
geometric, geometric,
gumbel, gumbel,
halfcauchy, halfcauchy,
...@@ -1093,6 +1094,48 @@ def test_betabinom_samples(M, a, p, size): ...@@ -1093,6 +1094,48 @@ def test_betabinom_samples(M, a, p, size):
) )
@pytest.mark.parametrize(
"alpha, p, lambd, size",
[
(
np.array(2, dtype=config.floatX),
np.array(3, dtype=config.floatX),
np.array(5, dtype=config.floatX),
None,
),
(
np.array(1, dtype=config.floatX),
np.array(1, dtype=config.floatX),
np.array(10, dtype=config.floatX),
[],
),
(
np.array(2, dtype=config.floatX),
np.array(2, dtype=config.floatX),
np.array(10, dtype=config.floatX),
[2, 3],
),
(
np.full((1, 2), 2, dtype=config.floatX),
np.array(2, dtype=config.floatX),
np.array(10, dtype=config.floatX),
None,
),
],
)
def test_gengamma_samples(alpha, p, lambd, size):
compare_sample_values(
gengamma,
alpha,
p,
lambd,
size=size,
test_fn=lambda *args, size=None, random_state=None, **kwargs: gengamma.rng_fn(
random_state, *(args + (size,))
),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, p, size, test_fn", "M, p, size, test_fn",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论