提交 3d796783 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Add docstring for `MvNormalRV`

上级 e191ed97
...@@ -761,6 +761,18 @@ def safe_multivariate_normal(mean, cov, size=None, rng=None): ...@@ -761,6 +761,18 @@ def safe_multivariate_normal(mean, cov, size=None, rng=None):
class MvNormalRV(RandomVariable): class MvNormalRV(RandomVariable):
r"""A multivariate normal random variable.
The probability density function for `multivariate_normal` in term of its location parameter
:math:`\boldsymbol{\mu}` and covariance matrix :math:`\Sigma` is
.. math::
f(\boldsymbol{x}; \boldsymbol{\mu}, \Sigma) = \det(2 \pi \Sigma)^{-1/2} \exp\left(-\frac{1}{2} (\boldsymbol{x} - \boldsymbol{\mu})^T \Sigma (\boldsymbol{x} - \boldsymbol{\mu})\right)
where :math:`\Sigma` is a positive semi-definite matrix.
"""
name = "multivariate_normal" name = "multivariate_normal"
ndim_supp = 1 ndim_supp = 1
ndims_params = [1, 2] ndims_params = [1, 2]
...@@ -768,7 +780,23 @@ class MvNormalRV(RandomVariable): ...@@ -768,7 +780,23 @@ class MvNormalRV(RandomVariable):
_print_name = ("N", "\\operatorname{N}") _print_name = ("N", "\\operatorname{N}")
def __call__(self, mean=None, cov=None, size=None, **kwargs): def __call__(self, mean=None, cov=None, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution.
Parameters
----------
mean
Location parameter (mean) :math:`\boldsymbol{\mu}` of the distribution. Vector
of length `N`.
cov
Covariance matrix :math:`\Sigma` of the distribution. Must be a symmetric
and positive-semidefinite `NxN` matrix.
size
Given a size of, for example, `(m, n, k)`, `m * n * k` independent,
identically distributed samples are generated. Because each sample
is `N`-dimensional, the output shape is `(m, n, k, N)`. If no shape
is specified, a single `N`-dimensional sample is returned.
"""
dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype
if mean is None: if mean is None:
......
...@@ -106,6 +106,9 @@ Aesara can produce :class:`RandomVariable`\s that draw samples from many differe ...@@ -106,6 +106,9 @@ Aesara can produce :class:`RandomVariable`\s that draw samples from many differe
.. autoclass:: aesara.tensor.random.basic.LogNormalRV .. autoclass:: aesara.tensor.random.basic.LogNormalRV
:members: __call__ :members: __call__
.. autoclass:: aesara.tensor.random.basic.MvNormalRV
:members: __call__
.. autoclass:: aesara.tensor.random.basic.NegBinomialRV .. autoclass:: aesara.tensor.random.basic.NegBinomialRV
:members: __call__ :members: __call__
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论