提交 06c17926 authored 作者: Chris Fonnesbeck's avatar Chris Fonnesbeck 提交者: Brandon T. Willard

Added Gumbel RV

上级 16ee436d
......@@ -46,3 +46,4 @@ core
aesara-venv/
/notebooks/Sandbox*
.vscode/
from typing import List, Optional, Union
import numpy as np
import scipy.stats as stats
......@@ -108,6 +110,36 @@ class ParetoRV(RandomVariable):
pareto = ParetoRV()
class GumbelRV(RandomVariable):
name = "gumbel"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("Gumbel", "\\operatorname{Gumbel}")
def __call__(
self,
loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float] = 1.0,
size: Optional[Union[List[int], int]] = None,
**kwargs
) -> RandomVariable:
return super().__call__(loc, scale, size=size, **kwargs)
@classmethod
def rng_fn(
cls,
rng: np.random.RandomState,
loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float],
size: Optional[Union[List[int], int]],
) -> np.ndarray:
return stats.gumbel_r.rvs(loc=loc, scale=scale, size=size, random_state=rng)
gumbel = GumbelRV()
class ExponentialRV(RandomVariable):
name = "exponential"
ndim_supp = 0
......
......@@ -23,6 +23,7 @@ from aesara.tensor.random.basic import (
dirichlet,
exponential,
gamma,
gumbel,
halfcauchy,
halfnormal,
invgamma,
......@@ -218,6 +219,14 @@ def test_gamma_samples():
rv_numpy_tester(gamma, test_a, test_b, size=[2, 3], test_fn=stats.gamma.rvs)
def test_gumbel_samples():
test_mu = np.array(0.0, dtype=config.floatX)
test_beta = np.array(1.0, dtype=config.floatX)
rv_numpy_tester(gumbel, test_mu, test_beta, test_fn=stats.gumbel_r.rvs)
rv_numpy_tester(gumbel, test_mu, test_beta, size=[2, 3], test_fn=stats.gumbel_r.rvs)
def test_exponential_samples():
rv_numpy_tester(exponential)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论