Unverified 提交 b852bd24 authored 作者: Kaustubh's avatar Kaustubh 提交者: GitHub

Added Laplace and Wald RV (#321)

* Added TuncatedNormal, Laplace, Wald and Kumaraswamy RV * Removed non-numpy related RVs
上级 70ed6548
......@@ -258,6 +258,20 @@ class InvGammaRV(RandomVariable):
invgamma = InvGammaRV()
class WaldRV(RandomVariable):
name = "wald"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name_ = ("Wald", "\\operatorname{Wald}")
def __call__(self, mean=1.0, scale=1.0, size=None, **kwargs):
return super().__call__(mean, scale, size=size, **kwargs)
wald = WaldRV()
class TruncExponentialRV(RandomVariable):
name = "truncexpon"
ndim_supp = 0
......@@ -290,6 +304,17 @@ class BernoulliRV(RandomVariable):
bernoulli = BernoulliRV()
class LaplaceRV(RandomVariable):
name = "laplace"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("Laplace", "\\operatorname{Laplace}")
laplace = LaplaceRV()
class BinomialRV(RandomVariable):
name = "binomial"
ndim_supp = 0
......
......@@ -25,6 +25,7 @@ from aesara.tensor.random.basic import (
halfcauchy,
halfnormal,
invgamma,
laplace,
multinomial,
multivariate_normal,
nbinom,
......@@ -35,6 +36,7 @@ from aesara.tensor.random.basic import (
randint,
truncexpon,
uniform,
wald,
)
from aesara.tensor.type import iscalar, scalar, tensor
......@@ -464,6 +466,14 @@ def test_invgamma_samples():
)
def test_wald_samples():
test_mean = np.array(10, dtype=config.floatX)
test_scale = np.array(1, dtype=config.floatX)
rv_numpy_tester(wald, test_mean, test_scale)
rv_numpy_tester(wald, test_mean, test_scale, size=[2, 3])
def test_truncexpon_samples():
test_b = np.array(5, dtype=config.floatX)
test_loc = np.array(0, dtype=config.floatX)
......@@ -498,6 +508,14 @@ def test_bernoulli_samples():
)
def test_laplace_samples():
test_loc = np.array(10, dtype=config.floatX)
test_scale = np.array(5, dtype=config.floatX)
rv_numpy_tester(laplace, test_loc, test_scale)
rv_numpy_tester(laplace, test_loc, test_scale, size=[2, 3])
def test_binomial_samples():
test_M = np.array(10, dtype="int64")
test_p = np.array(0.5, dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论