提交 9ffd222f authored 作者: Larry Dong's avatar Larry Dong 提交者: ricardoV94

Updated ChiSquare r.v. and added tests

上级 4138ac2f
...@@ -117,32 +117,15 @@ class GammaRV(RandomVariable): ...@@ -117,32 +117,15 @@ class GammaRV(RandomVariable):
gamma = GammaRV() gamma = GammaRV()
class ChiSquaredRV(RandomVariable): class ChiSquareRV(RandomVariable):
name = "chi_squared" name = "chisquare"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0] ndims_params = [0]
dtype = "floatX" dtype = "floatX"
_print_name = ("ChiSquared", "\\operatorname{ChiSquared}") _print_name = ("ChiSquare", "\\operatorname{ChiSquare}")
def __call__(
self,
nu: Union[np.ndarray, float],
size: Optional[Union[List[int], int]] = None,
**kwargs
) -> RandomVariable:
return super().call(nu=nu, size=size, **kwargs)
@classmethod
def rng_fn(
cls,
rng: np.random.RandomState,
nu: Union[np.ndarray, float],
size: Optional[Union[List[int], int]],
) -> np.ndarray:
return stats.chi2.rvs(nu, size=size, random_state=rng)
chi_squared = ChiSquaredRV() chisquare = ChiSquareRV()
class ParetoRV(RandomVariable): class ParetoRV(RandomVariable):
......
...@@ -19,6 +19,7 @@ from aesara.tensor.random.basic import ( ...@@ -19,6 +19,7 @@ from aesara.tensor.random.basic import (
binomial, binomial,
categorical, categorical,
cauchy, cauchy,
chisquare,
choice, choice,
dirichlet, dirichlet,
exponential, exponential,
...@@ -244,6 +245,13 @@ def test_gamma_samples(): ...@@ -244,6 +245,13 @@ def test_gamma_samples():
rv_numpy_tester(gamma, test_a, test_b, size=[2, 3], test_fn=stats.gamma.rvs) rv_numpy_tester(gamma, test_a, test_b, size=[2, 3], test_fn=stats.gamma.rvs)
def test_chisquare_samples():
test_df = np.array(2, dtype=config.floatX)
rv_numpy_tester(chisquare, test_df, test_fn=stats.chi2.rvs)
rv_numpy_tester(chisquare, test_df, size=[2, 3], test_fn=stats.gamma.rvs)
def test_gumbel_samples(): def test_gumbel_samples():
test_mu = np.array(0.0, dtype=config.floatX) test_mu = np.array(0.0, dtype=config.floatX)
test_beta = np.array(1.0, dtype=config.floatX) test_beta = np.array(1.0, dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论