提交 2aecb956 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow decomposition methods in MvNormal

上级 2823dfca
......@@ -128,7 +128,6 @@ def jax_sample_fn(op, node):
@jax_sample_fn.register(ptr.BetaRV)
@jax_sample_fn.register(ptr.DirichletRV)
@jax_sample_fn.register(ptr.PoissonRV)
@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_fn_generic(op, node):
"""Generic JAX implementation of random variables."""
name = op.name
......@@ -173,6 +172,20 @@ def jax_sample_fn_loc_scale(op, node):
return sample_fn
@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_mvnormal(op, node):
def sample_fn(rng, size, dtype, mean, cov):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax.random.multivariate_normal(
sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
)
rng["jax_state"] = rng_key
return (rng, sample)
return sample_fn
@jax_sample_fn.register(ptr.BernoulliRV)
def jax_sample_fn_bernoulli(op, node):
"""JAX implementation of `BernoulliRV`."""
......
......@@ -144,11 +144,24 @@ def core_CategoricalRV(op, node):
@numba_core_rv_funcify.register(ptr.MvNormalRV)
def core_MvNormalRV(op, node):
method = op.method
@numba_basic.numba_njit
def random_fn(rng, mean, cov):
chol = np.linalg.cholesky(cov)
stdnorm = rng.normal(size=cov.shape[-1])
return np.dot(chol, stdnorm) + mean
if method == "cholesky":
A = np.linalg.cholesky(cov)
elif method == "svd":
A, s, _ = np.linalg.svd(cov)
A *= np.sqrt(s)[None, :]
else:
w, A = np.linalg.eigh(cov)
A *= np.sqrt(w)[None, :]
out = rng.normal(size=cov.shape[-1])
# out argument not working correctly: https://github.com/numba/numba/issues/9924
out[:] = np.dot(A, out)
out += mean
return out
random_fn.handles_out = True
return random_fn
......
import abc
import warnings
from typing import Literal
import numpy as np
import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum
from numpy import sqrt as np_sqrt
from numpy.linalg import cholesky as np_cholesky
from numpy.linalg import eigh as np_eigh
from numpy.linalg import svd as np_svd
import pytensor
from pytensor.tensor import get_vector_length, specify_shape
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import sqrt
......@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable):
signature = "(n),(n,n)->(n)"
dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
__props__ = ("name", "signature", "dtype", "inplace", "method")
def __call__(self, mean=None, cov=None, size=None, **kwargs):
def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs):
super().__init__(*args, **kwargs)
if method not in ("cholesky", "svd", "eigh"):
raise ValueError(
f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'."
)
self.method = method
def __call__(self, mean, cov, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution.
Signature
......@@ -876,33 +888,34 @@ class MvNormalRV(RandomVariable):
is specified, a single `N`-dimensional sample is returned.
"""
dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
if mean is None:
mean = np.array([0.0], dtype=dtype)
if cov is None:
cov = np.array([[1.0]], dtype=dtype)
return super().__call__(mean, cov, size=size, **kwargs)
@classmethod
def rng_fn(cls, rng, mean, cov, size):
def rng_fn(self, rng, mean, cov, size):
if size is None:
size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
chol = np_cholesky(cov)
if self.method == "cholesky":
A = np_cholesky(cov)
elif self.method == "svd":
A, s, _ = np_svd(cov)
A *= np_sqrt(s, out=s)[..., None, :]
else:
w, A = np_eigh(cov)
A *= np_sqrt(w, out=w)[..., None, :]
out = rng.normal(size=(*size, mean.shape[-1]))
np_einsum(
"...ij,...j->...i", # numpy doesn't have a batch matrix-vector product
chol,
A,
out,
out=out,
optimize=False, # Nothing to optimize with two operands, skip costly setup
out=out,
)
out += mean
return out
multivariate_normal = MvNormalRV()
multivariate_normal = MvNormalRV(method="cholesky")
class DirichletRV(RandomVariable):
......
......@@ -18,6 +18,7 @@ from tests.tensor.random.test_basic import (
batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester,
create_mvnormal_cov_decomposition_method_test,
)
......@@ -547,6 +548,11 @@ def test_random_mvnormal():
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
"JAX"
)
@pytest.mark.parametrize(
"parameter, size",
[
......
......@@ -22,6 +22,7 @@ from tests.tensor.random.test_basic import (
batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester,
create_mvnormal_cov_decomposition_method_test,
)
......@@ -147,6 +148,11 @@ def test_multivariate_normal():
)
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
"NUMBA"
)
@pytest.mark.parametrize(
"rv_op, dist_args, size",
[
......
......@@ -19,6 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import ones, stack
from pytensor.tensor.random.basic import (
ChoiceWithoutReplacement,
MvNormalRV,
PermutationRV,
_gamma,
bernoulli,
......@@ -686,6 +687,49 @@ def test_mvnormal_ShapeFeature():
assert s4.get_test_value() == 3
def create_mvnormal_cov_decomposition_method_test(mode):
@pytest.mark.parametrize("psd", (True, False))
@pytest.mark.parametrize("method", ("cholesky", "svd", "eigh"))
def test_mvnormal_cov_decomposition_method(method, psd):
mean = 2 ** np.arange(3)
if psd:
cov = [
[1, 0.5, -1],
[0.5, 2, 0],
[-1, 0, 3],
]
else:
cov = [
[1, 0.5, 0],
[0.5, 2, 0],
[0, 0, 0],
]
rng = shared(np.random.default_rng(675))
draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
assert draws.owner.op.method == method
# JAX doesn't raise errors at runtime
if not psd and method == "cholesky":
if mode == "JAX":
# JAX doesn't raise errors at runtime, instead it returns nan
np.isnan(draws.eval(mode=mode)).all()
else:
with pytest.raises(np.linalg.LinAlgError):
draws.eval(mode=mode)
else:
draws_eval = draws.eval(mode=mode)
np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.02)
np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1)
return test_mvnormal_cov_decomposition_method
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
None
)
@pytest.mark.parametrize(
"alphas, size",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论