提交 02378861 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix scalar size issue in CategoricalRV

上级 ce9c17f3
...@@ -413,10 +413,14 @@ class CategoricalRV(RandomVariable): ...@@ -413,10 +413,14 @@ class CategoricalRV(RandomVariable):
size = tuple(np.atleast_1d(size)) size = tuple(np.atleast_1d(size))
ind_shape = p.shape[:-1] ind_shape = p.shape[:-1]
if len(size) > 0 and size[-len(ind_shape) :] != ind_shape: if len(ind_shape) > 0:
raise ValueError("Parameters shape and size do not match.") if len(size) > 0 and size[-len(ind_shape) :] != ind_shape:
raise ValueError("Parameters shape and size do not match.")
samples_shape = size[: -len(ind_shape)] + ind_shape
else:
samples_shape = size
samples_shape = size[: -len(ind_shape)] + ind_shape
unif_samples = rng.uniform(size=samples_shape) unif_samples = rng.uniform(size=samples_shape)
samples = vsearchsorted(p.cumsum(axis=-1), unif_samples) samples = vsearchsorted(p.cumsum(axis=-1), unif_samples)
......
...@@ -589,6 +589,10 @@ def test_categorical_samples(): ...@@ -589,6 +589,10 @@ def test_categorical_samples():
rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234))) rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
assert categorical.rng_fn(rng_state, np.array([1.0 / 3.0] * 3), size=10).shape == (
10,
)
p = np.array([[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], dtype=config.floatX) p = np.array([[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], dtype=config.floatX)
p = p / p.sum(axis=-1) p = p / p.sum(axis=-1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论