提交 2aecb956 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow decomposition methods in MvNormal

上级 2823dfca
...@@ -128,7 +128,6 @@ def jax_sample_fn(op, node): ...@@ -128,7 +128,6 @@ def jax_sample_fn(op, node):
@jax_sample_fn.register(ptr.BetaRV) @jax_sample_fn.register(ptr.BetaRV)
@jax_sample_fn.register(ptr.DirichletRV) @jax_sample_fn.register(ptr.DirichletRV)
@jax_sample_fn.register(ptr.PoissonRV) @jax_sample_fn.register(ptr.PoissonRV)
@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_fn_generic(op, node): def jax_sample_fn_generic(op, node):
"""Generic JAX implementation of random variables.""" """Generic JAX implementation of random variables."""
name = op.name name = op.name
...@@ -173,6 +172,20 @@ def jax_sample_fn_loc_scale(op, node): ...@@ -173,6 +172,20 @@ def jax_sample_fn_loc_scale(op, node):
return sample_fn return sample_fn
@jax_sample_fn.register(ptr.MvNormalRV)
def jax_sample_mvnormal(op, node):
def sample_fn(rng, size, dtype, mean, cov):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax.random.multivariate_normal(
sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
)
rng["jax_state"] = rng_key
return (rng, sample)
return sample_fn
@jax_sample_fn.register(ptr.BernoulliRV) @jax_sample_fn.register(ptr.BernoulliRV)
def jax_sample_fn_bernoulli(op, node): def jax_sample_fn_bernoulli(op, node):
"""JAX implementation of `BernoulliRV`.""" """JAX implementation of `BernoulliRV`."""
......
...@@ -144,11 +144,24 @@ def core_CategoricalRV(op, node): ...@@ -144,11 +144,24 @@ def core_CategoricalRV(op, node):
@numba_core_rv_funcify.register(ptr.MvNormalRV) @numba_core_rv_funcify.register(ptr.MvNormalRV)
def core_MvNormalRV(op, node): def core_MvNormalRV(op, node):
method = op.method
@numba_basic.numba_njit @numba_basic.numba_njit
def random_fn(rng, mean, cov): def random_fn(rng, mean, cov):
chol = np.linalg.cholesky(cov) if method == "cholesky":
stdnorm = rng.normal(size=cov.shape[-1]) A = np.linalg.cholesky(cov)
return np.dot(chol, stdnorm) + mean elif method == "svd":
A, s, _ = np.linalg.svd(cov)
A *= np.sqrt(s)[None, :]
else:
w, A = np.linalg.eigh(cov)
A *= np.sqrt(w)[None, :]
out = rng.normal(size=cov.shape[-1])
# out argument not working correctly: https://github.com/numba/numba/issues/9924
out[:] = np.dot(A, out)
out += mean
return out
random_fn.handles_out = True random_fn.handles_out = True
return random_fn return random_fn
......
import abc import abc
import warnings import warnings
from typing import Literal
import numpy as np import numpy as np
import scipy.stats as stats import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum from numpy import einsum as np_einsum
from numpy import sqrt as np_sqrt
from numpy.linalg import cholesky as np_cholesky from numpy.linalg import cholesky as np_cholesky
from numpy.linalg import eigh as np_eigh
from numpy.linalg import svd as np_svd
import pytensor
from pytensor.tensor import get_vector_length, specify_shape from pytensor.tensor import get_vector_length, specify_shape
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import sqrt from pytensor.tensor.math import sqrt
...@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable): ...@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable):
signature = "(n),(n,n)->(n)" signature = "(n),(n,n)->(n)"
dtype = "floatX" dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
__props__ = ("name", "signature", "dtype", "inplace", "method")
def __call__(self, mean=None, cov=None, size=None, **kwargs): def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs):
super().__init__(*args, **kwargs)
if method not in ("cholesky", "svd", "eigh"):
raise ValueError(
f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'."
)
self.method = method
def __call__(self, mean, cov, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution. r""" "Draw samples from a multivariate normal distribution.
Signature Signature
...@@ -876,33 +888,34 @@ class MvNormalRV(RandomVariable): ...@@ -876,33 +888,34 @@ class MvNormalRV(RandomVariable):
is specified, a single `N`-dimensional sample is returned. is specified, a single `N`-dimensional sample is returned.
""" """
dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
if mean is None:
mean = np.array([0.0], dtype=dtype)
if cov is None:
cov = np.array([[1.0]], dtype=dtype)
return super().__call__(mean, cov, size=size, **kwargs) return super().__call__(mean, cov, size=size, **kwargs)
@classmethod def rng_fn(self, rng, mean, cov, size):
def rng_fn(cls, rng, mean, cov, size):
if size is None: if size is None:
size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
chol = np_cholesky(cov) if self.method == "cholesky":
A = np_cholesky(cov)
elif self.method == "svd":
A, s, _ = np_svd(cov)
A *= np_sqrt(s, out=s)[..., None, :]
else:
w, A = np_eigh(cov)
A *= np_sqrt(w, out=w)[..., None, :]
out = rng.normal(size=(*size, mean.shape[-1])) out = rng.normal(size=(*size, mean.shape[-1]))
np_einsum( np_einsum(
"...ij,...j->...i", # numpy doesn't have a batch matrix-vector product "...ij,...j->...i", # numpy doesn't have a batch matrix-vector product
chol, A,
out, out,
out=out,
optimize=False, # Nothing to optimize with two operands, skip costly setup optimize=False, # Nothing to optimize with two operands, skip costly setup
out=out,
) )
out += mean out += mean
return out return out
multivariate_normal = MvNormalRV() multivariate_normal = MvNormalRV(method="cholesky")
class DirichletRV(RandomVariable): class DirichletRV(RandomVariable):
......
...@@ -18,6 +18,7 @@ from tests.tensor.random.test_basic import ( ...@@ -18,6 +18,7 @@ from tests.tensor.random.test_basic import (
batched_permutation_tester, batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester, batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester, batched_weighted_choice_without_replacement_tester,
create_mvnormal_cov_decomposition_method_test,
) )
...@@ -547,6 +548,11 @@ def test_random_mvnormal(): ...@@ -547,6 +548,11 @@ def test_random_mvnormal():
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1) np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
"JAX"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"parameter, size", "parameter, size",
[ [
......
...@@ -22,6 +22,7 @@ from tests.tensor.random.test_basic import ( ...@@ -22,6 +22,7 @@ from tests.tensor.random.test_basic import (
batched_permutation_tester, batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester, batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester, batched_weighted_choice_without_replacement_tester,
create_mvnormal_cov_decomposition_method_test,
) )
...@@ -147,6 +148,11 @@ def test_multivariate_normal(): ...@@ -147,6 +148,11 @@ def test_multivariate_normal():
) )
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
"NUMBA"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rv_op, dist_args, size", "rv_op, dist_args, size",
[ [
......
...@@ -19,6 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -19,6 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import ones, stack from pytensor.tensor import ones, stack
from pytensor.tensor.random.basic import ( from pytensor.tensor.random.basic import (
ChoiceWithoutReplacement, ChoiceWithoutReplacement,
MvNormalRV,
PermutationRV, PermutationRV,
_gamma, _gamma,
bernoulli, bernoulli,
...@@ -686,6 +687,49 @@ def test_mvnormal_ShapeFeature(): ...@@ -686,6 +687,49 @@ def test_mvnormal_ShapeFeature():
assert s4.get_test_value() == 3 assert s4.get_test_value() == 3
def create_mvnormal_cov_decomposition_method_test(mode):
@pytest.mark.parametrize("psd", (True, False))
@pytest.mark.parametrize("method", ("cholesky", "svd", "eigh"))
def test_mvnormal_cov_decomposition_method(method, psd):
mean = 2 ** np.arange(3)
if psd:
cov = [
[1, 0.5, -1],
[0.5, 2, 0],
[-1, 0, 3],
]
else:
cov = [
[1, 0.5, 0],
[0.5, 2, 0],
[0, 0, 0],
]
rng = shared(np.random.default_rng(675))
draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
assert draws.owner.op.method == method
# JAX doesn't raise errors at runtime
if not psd and method == "cholesky":
if mode == "JAX":
# JAX doesn't raise errors at runtime, instead it returns nan
np.isnan(draws.eval(mode=mode)).all()
else:
with pytest.raises(np.linalg.LinAlgError):
draws.eval(mode=mode)
else:
draws_eval = draws.eval(mode=mode)
np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.02)
np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1)
return test_mvnormal_cov_decomposition_method
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
None
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"alphas, size", "alphas, size",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论