提交 ce9c17f3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make MultinomialRV broadcast across parameters

上级 45ba2d1f
......@@ -375,6 +375,23 @@ class MultinomialRV(RandomVariable):
self.ndim_supp, dist_params, rep_param_idx, param_shapes
)
@classmethod
def rng_fn(cls, rng, n, p, size):
if n.ndim > 0 or p.ndim > 1:
n, p = broadcast_params([n, p], cls.ndims_params)
size = tuple(size or ())
if size:
n = np.broadcast_to(n, size + n.shape)
p = np.broadcast_to(p, size + p.shape)
res = np.empty(p.shape)
for idx in np.ndindex(p.shape[:-1]):
res[idx] = rng.multinomial(n[idx], p[idx])
return res
else:
return rng.multinomial(n, p, size=size)
multinomial = MultinomialRV()
......
......@@ -6,6 +6,7 @@ import scipy.stats as stats
from pytest import fixture, importorskip, raises
import aesara.tensor as aet
from aesara import shared
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
......@@ -568,6 +569,21 @@ def test_multinomial_samples():
size=[2, 3],
)
rng_state = shared(
np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
)
test_M = np.array([10, 20], dtype="int64")
test_p = np.array([[0.999, 0.001], [0.001, 0.999]], dtype=config.floatX)
res = multinomial(test_M, test_p, rng=rng_state).eval()
exp_res = np.array([[10, 0], [0, 20]])
assert np.array_equal(res, exp_res)
res = multinomial(test_M, test_p, size=(3,), rng=rng_state).eval()
exp_res = np.stack([exp_res] * 3)
assert np.array_equal(res, exp_res)
def test_categorical_samples():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论