提交 ce9c17f3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make MultinomialRV broadcast across parameters

上级 45ba2d1f
...@@ -375,6 +375,23 @@ class MultinomialRV(RandomVariable): ...@@ -375,6 +375,23 @@ class MultinomialRV(RandomVariable):
self.ndim_supp, dist_params, rep_param_idx, param_shapes self.ndim_supp, dist_params, rep_param_idx, param_shapes
) )
@classmethod
def rng_fn(cls, rng, n, p, size):
if n.ndim > 0 or p.ndim > 1:
n, p = broadcast_params([n, p], cls.ndims_params)
size = tuple(size or ())
if size:
n = np.broadcast_to(n, size + n.shape)
p = np.broadcast_to(p, size + p.shape)
res = np.empty(p.shape)
for idx in np.ndindex(p.shape[:-1]):
res[idx] = rng.multinomial(n[idx], p[idx])
return res
else:
return rng.multinomial(n, p, size=size)
multinomial = MultinomialRV() multinomial = MultinomialRV()
......
...@@ -6,6 +6,7 @@ import scipy.stats as stats ...@@ -6,6 +6,7 @@ import scipy.stats as stats
from pytest import fixture, importorskip, raises from pytest import fixture, importorskip, raises
import aesara.tensor as aet import aesara.tensor as aet
from aesara import shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -568,6 +569,21 @@ def test_multinomial_samples(): ...@@ -568,6 +569,21 @@ def test_multinomial_samples():
size=[2, 3], size=[2, 3],
) )
rng_state = shared(
np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
)
test_M = np.array([10, 20], dtype="int64")
test_p = np.array([[0.999, 0.001], [0.001, 0.999]], dtype=config.floatX)
res = multinomial(test_M, test_p, rng=rng_state).eval()
exp_res = np.array([[10, 0], [0, 20]])
assert np.array_equal(res, exp_res)
res = multinomial(test_M, test_p, size=(3,), rng=rng_state).eval()
exp_res = np.stack([exp_res] * 3)
assert np.array_equal(res, exp_res)
def test_categorical_samples(): def test_categorical_samples():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论