提交 2823dfca authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Faster python implementation of MvNormal

Also remove bad default values
上级 1ed36119
......@@ -3,6 +3,9 @@ import warnings
import numpy as np
import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum
from numpy.linalg import cholesky as np_cholesky
import pytensor
from pytensor.tensor import get_vector_length, specify_shape
......@@ -831,27 +834,6 @@ class VonMisesRV(RandomVariable):
vonmises = VonMisesRV()
def safe_multivariate_normal(mean, cov, size=None, rng=None):
"""A shape consistent multivariate normal sampler.
What we mean by "shape consistent": SciPy will return scalars when the
arguments are vectors with dimension of size 1. We require that the output
be at least 1D, so that it's consistent with the underlying random
variable.
"""
res = np.atleast_1d(
stats.multivariate_normal(mean=mean, cov=cov, allow_singular=True).rvs(
size=size, random_state=rng
)
)
if size is not None:
res = res.reshape([*size, -1])
return res
class MvNormalRV(RandomVariable):
r"""A multivariate normal random variable.
......@@ -904,25 +886,20 @@ class MvNormalRV(RandomVariable):
@classmethod
def rng_fn(cls, rng, mean, cov, size):
if mean.ndim > 1 or cov.ndim > 2:
# Neither SciPy nor NumPy implement parameter broadcasting for
# multivariate normals (or any other multivariate distributions),
# so we need to implement that here
if size is None:
mean, cov = broadcast_params([mean, cov], [1, 2])
else:
mean = np.broadcast_to(mean, size + mean.shape[-1:])
cov = np.broadcast_to(cov, size + cov.shape[-2:])
res = np.empty(mean.shape)
for idx in np.ndindex(mean.shape[:-1]):
m = mean[idx]
c = cov[idx]
res[idx] = safe_multivariate_normal(m, c, rng=rng)
return res
else:
return safe_multivariate_normal(mean, cov, size=size, rng=rng)
if size is None:
size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
chol = np_cholesky(cov)
out = rng.normal(size=(*size, mean.shape[-1]))
np_einsum(
"...ij,...j->...i", # numpy doesn't have a batch matrix-vector product
chol,
out,
out=out,
optimize=False, # Nothing to optimize with two operands, skip costly setup
)
out += mean
return out
multivariate_normal = MvNormalRV()
......
......@@ -778,8 +778,10 @@ def rand_bool_mask(shape, rng=None):
multivariate_normal,
(
np.array([200, 250], dtype=config.floatX),
# Second covariance is invalid, to test it is not chosen
np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
# Second covariance is very large, to test it is not chosen
np.dstack([np.eye(2), np.eye(2) * 1000, np.eye(2)]).T.astype(
config.floatX
)
* 1e-6,
),
(3,),
......
......@@ -521,13 +521,19 @@ def test_pareto_samples(alpha, scale, size):
def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
if mean is None:
mean = np.array([0.0], dtype=config.floatX)
if cov is None:
cov = np.array([[1.0]], dtype=config.floatX)
if size is not None:
size = tuple(size)
return multivariate_normal.rng_fn(random_state, mean, cov, size)
rng = random_state if random_state is not None else np.random.default_rng()
if size is None:
size = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
mean = np.broadcast_to(mean, (*size, *mean.shape[-1:]))
cov = np.broadcast_to(cov, (*size, *cov.shape[-2:]))
@np.vectorize(signature="(n),(n,n)->(n)")
def vec_mvnormal(mean, cov):
return rng.multivariate_normal(mean, cov, method="cholesky")
return vec_mvnormal(mean, cov)
@pytest.mark.parametrize(
......@@ -609,18 +615,30 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
),
],
)
@pytest.mark.skipif(
config.floatX == "float32",
reason="Draws are only strictly equal to numpy in float64",
)
def test_mvnormal_samples(mu, cov, size):
compare_sample_values(
multivariate_normal, mu, cov, size=size, test_fn=mvnormal_test_fn
)
def test_mvnormal_default_args():
compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn)
def test_mvnormal_no_default_args():
with pytest.raises(
TypeError, match="missing 2 required positional arguments: 'mean' and 'cov'"
):
multivariate_normal()
def test_mvnormal_impl_catches_incompatible_size():
with pytest.raises(ValueError, match="operands could not be broadcast together "):
multivariate_normal.rng_fn(
None, np.zeros((3, 2)), np.ones((3, 2, 2)), size=(4,)
np.random.default_rng(),
np.zeros((3, 2)),
np.broadcast_to(np.eye(2), (3, 2, 2)),
size=(4,),
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论