Unverified 提交 b852bd24 authored 作者: Kaustubh's avatar Kaustubh 提交者: GitHub

Added Laplace and Wald RV (#321)

* Added TuncatedNormal, Laplace, Wald and Kumaraswamy RV * Removed non-numpy related RVs
上级 70ed6548
...@@ -258,6 +258,20 @@ class InvGammaRV(RandomVariable): ...@@ -258,6 +258,20 @@ class InvGammaRV(RandomVariable):
invgamma = InvGammaRV() invgamma = InvGammaRV()
class WaldRV(RandomVariable):
name = "wald"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name_ = ("Wald", "\\operatorname{Wald}")
def __call__(self, mean=1.0, scale=1.0, size=None, **kwargs):
return super().__call__(mean, scale, size=size, **kwargs)
wald = WaldRV()
class TruncExponentialRV(RandomVariable): class TruncExponentialRV(RandomVariable):
name = "truncexpon" name = "truncexpon"
ndim_supp = 0 ndim_supp = 0
...@@ -290,6 +304,17 @@ class BernoulliRV(RandomVariable): ...@@ -290,6 +304,17 @@ class BernoulliRV(RandomVariable):
bernoulli = BernoulliRV() bernoulli = BernoulliRV()
class LaplaceRV(RandomVariable):
name = "laplace"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("Laplace", "\\operatorname{Laplace}")
laplace = LaplaceRV()
class BinomialRV(RandomVariable): class BinomialRV(RandomVariable):
name = "binomial" name = "binomial"
ndim_supp = 0 ndim_supp = 0
......
...@@ -25,6 +25,7 @@ from aesara.tensor.random.basic import ( ...@@ -25,6 +25,7 @@ from aesara.tensor.random.basic import (
halfcauchy, halfcauchy,
halfnormal, halfnormal,
invgamma, invgamma,
laplace,
multinomial, multinomial,
multivariate_normal, multivariate_normal,
nbinom, nbinom,
...@@ -35,6 +36,7 @@ from aesara.tensor.random.basic import ( ...@@ -35,6 +36,7 @@ from aesara.tensor.random.basic import (
randint, randint,
truncexpon, truncexpon,
uniform, uniform,
wald,
) )
from aesara.tensor.type import iscalar, scalar, tensor from aesara.tensor.type import iscalar, scalar, tensor
...@@ -464,6 +466,14 @@ def test_invgamma_samples(): ...@@ -464,6 +466,14 @@ def test_invgamma_samples():
) )
def test_wald_samples():
test_mean = np.array(10, dtype=config.floatX)
test_scale = np.array(1, dtype=config.floatX)
rv_numpy_tester(wald, test_mean, test_scale)
rv_numpy_tester(wald, test_mean, test_scale, size=[2, 3])
def test_truncexpon_samples(): def test_truncexpon_samples():
test_b = np.array(5, dtype=config.floatX) test_b = np.array(5, dtype=config.floatX)
test_loc = np.array(0, dtype=config.floatX) test_loc = np.array(0, dtype=config.floatX)
...@@ -498,6 +508,14 @@ def test_bernoulli_samples(): ...@@ -498,6 +508,14 @@ def test_bernoulli_samples():
) )
def test_laplace_samples():
test_loc = np.array(10, dtype=config.floatX)
test_scale = np.array(5, dtype=config.floatX)
rv_numpy_tester(laplace, test_loc, test_scale)
rv_numpy_tester(laplace, test_loc, test_scale, size=[2, 3])
def test_binomial_samples(): def test_binomial_samples():
test_M = np.array(10, dtype="int64") test_M = np.array(10, dtype="int64")
test_p = np.array(0.5, dtype=config.floatX) test_p = np.array(0.5, dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论