提交 8794f48d authored 作者: ricardoV94's avatar ricardoV94 提交者: Brandon T. Willard

Allow size to broadcast multivariate distribution parameters

上级 cac73bae
...@@ -331,19 +331,13 @@ class MvNormalRV(RandomVariable): ...@@ -331,19 +331,13 @@ class MvNormalRV(RandomVariable):
# Neither SciPy nor NumPy implement parameter broadcasting for # Neither SciPy nor NumPy implement parameter broadcasting for
# multivariate normals (or any other multivariate distributions), # multivariate normals (or any other multivariate distributions),
# so we need to implement that here # so we need to implement that here
mean, cov = broadcast_params([mean, cov], cls.ndims_params)
size = tuple(size or ())
size = tuple(size or ())
if size: if size:
if (
0 < mean.ndim - 1 <= len(size)
and size[-mean.ndim + 1 :] != mean.shape[:-1]
):
raise ValueError(
"shape mismatch: objects cannot be broadcast to a single shape"
)
mean = np.broadcast_to(mean, size + mean.shape[-1:]) mean = np.broadcast_to(mean, size + mean.shape[-1:])
cov = np.broadcast_to(cov, size + cov.shape[-2:]) cov = np.broadcast_to(cov, size + cov.shape[-2:])
else:
mean, cov = broadcast_params([mean, cov], cls.ndims_params)
res = np.empty(mean.shape) res = np.empty(mean.shape)
for idx in np.ndindex(mean.shape[:-1]): for idx in np.ndindex(mean.shape[:-1]):
...@@ -374,22 +368,12 @@ class DirichletRV(RandomVariable): ...@@ -374,22 +368,12 @@ class DirichletRV(RandomVariable):
size = tuple(np.atleast_1d(size)) size = tuple(np.atleast_1d(size))
if size: if size:
if ( alphas = np.broadcast_to(alphas, size + alphas.shape[-1:])
0 < alphas.ndim - 1 <= len(size)
and size[-alphas.ndim + 1 :] != alphas.shape[:-1]
):
raise ValueError(
"shape mismatch: objects cannot be broadcast to a single shape"
)
samples_shape = size + alphas.shape[-1:]
else:
samples_shape = alphas.shape
samples_shape = alphas.shape
samples = np.empty(samples_shape) samples = np.empty(samples_shape)
alphas_bcast = np.broadcast_to(alphas, samples_shape)
for index in np.ndindex(*samples_shape[:-1]): for index in np.ndindex(*samples_shape[:-1]):
samples[index] = rng.dirichlet(alphas_bcast[index]) samples[index] = rng.dirichlet(alphas[index])
return samples return samples
else: else:
...@@ -608,16 +592,13 @@ class MultinomialRV(RandomVariable): ...@@ -608,16 +592,13 @@ class MultinomialRV(RandomVariable):
@classmethod @classmethod
def rng_fn(cls, rng, n, p, size): def rng_fn(cls, rng, n, p, size):
if n.ndim > 0 or p.ndim > 1: if n.ndim > 0 or p.ndim > 1:
n, p = broadcast_params([n, p], cls.ndims_params)
size = tuple(size or ()) size = tuple(size or ())
if size: if size:
if 0 < p.ndim - 1 <= len(size) and size[-p.ndim + 1 :] != p.shape[:-1]:
raise ValueError(
"shape mismatch: objects cannot be broadcast to a single shape"
)
n = np.broadcast_to(n, size) n = np.broadcast_to(n, size)
p = np.broadcast_to(p, size + p.shape[-1:]) p = np.broadcast_to(p, size + p.shape[-1:])
else:
n, p = broadcast_params([n, p], cls.ndims_params)
res = np.empty(p.shape, dtype=cls.dtype) res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]): for idx in np.ndindex(p.shape[:-1]):
......
...@@ -578,9 +578,9 @@ def test_mvnormal_samples(mu, cov, size): ...@@ -578,9 +578,9 @@ def test_mvnormal_samples(mu, cov, size):
def test_mvnormal_default_args(): def test_mvnormal_default_args():
compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn) compare_sample_values(multivariate_normal, test_fn=mvnormal_test_fn)
with pytest.raises(ValueError, match="shape mismatch.*"): with pytest.raises(ValueError, match="operands could not be broadcast together "):
multivariate_normal.rng_fn( multivariate_normal.rng_fn(
None, np.zeros((1, 2)), np.ones((1, 2, 2)), size=(4,) None, np.zeros((3, 2)), np.ones((3, 2, 2)), size=(4,)
) )
...@@ -654,11 +654,17 @@ def test_dirichlet_samples(alphas, size): ...@@ -654,11 +654,17 @@ def test_dirichlet_samples(alphas, size):
def test_dirichlet_rng(): def test_dirichlet_rng():
alphas = np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX) alphas = np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)
with pytest.raises(ValueError, match="shape mismatch.*"): with pytest.raises(ValueError, match="operands could not be broadcast together"):
# The independent dimension's shape is missing from size (i.e. should # The independent dimension's shape cannot be broadcasted from (3,) to (10, 2)
# be `(10, 2, 3)`)
dirichlet.rng_fn(None, alphas, size=(10, 2)) dirichlet.rng_fn(None, alphas, size=(10, 2))
with pytest.raises(
ValueError, match="input operand has more dimensions than allowed"
):
# One of the independent dimension's shape is missing from size
# (i.e. should be `(1, 3)`)
dirichlet.rng_fn(None, np.broadcast_to(alphas, (1, 3, 3)), size=(3,))
M_at = iscalar("M") M_at = iscalar("M")
M_at.tag.test_value = 3 M_at.tag.test_value = 3
...@@ -1146,11 +1152,17 @@ def test_multinomial_rng(): ...@@ -1146,11 +1152,17 @@ def test_multinomial_rng():
test_M = np.array([10, 20], dtype=np.int64) test_M = np.array([10, 20], dtype=np.int64)
test_p = np.array([[0.999, 0.001], [0.001, 0.999]], dtype=config.floatX) test_p = np.array([[0.999, 0.001], [0.001, 0.999]], dtype=config.floatX)
with pytest.raises(ValueError, match="shape mismatch.*"): with pytest.raises(ValueError, match="operands could not be broadcast together"):
# The independent dimension's shape is missing from size (i.e. should # The independent dimension's shape cannot be broadcasted from (2,) to (1,)
# be `(1, 2)`)
multinomial.rng_fn(None, test_M, test_p, size=(1,)) multinomial.rng_fn(None, test_M, test_p, size=(1,))
with pytest.raises(
ValueError, match="input operand has more dimensions than allowed"
):
# One of the independent dimension's shape is missing from size
# (i.e. should be `(5, 2)`)
multinomial.rng_fn(None, np.broadcast_to(test_M, (5, 2)), test_p, size=(2,))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"p, size, test_fn", "p, size, test_fn",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论