Commit fd50f36b authored by ricardoV94, committed by Brandon T. Willard

Allow size to broadcast categorical p argument

Parent 58b38f93
...@@ -312,33 +312,21 @@ def numba_funcify_BernoulliRV(op, node, **kwargs): ...@@ -312,33 +312,21 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
@numba_funcify.register(aer.CategoricalRV)
def numba_funcify_CategoricalRV(op, node, **kwargs):
    """Generate a Numba implementation for sampling `CategoricalRV`.

    The returned jitted function draws category indices by inverting each
    distribution's CDF with ``np.searchsorted`` against uniform draws.
    The ``size`` argument is allowed to broadcast the independent (leading)
    dimensions of the ``p`` parameter.
    """
    out_dtype = node.outputs[1].type.numpy_dtype
    # Static length of the `size` vector, resolved at funcify time so the
    # jitted function can turn `size` into a fixed-length tuple.
    size_len = int(get_vector_length(node.inputs[1]))

    @numba_basic.numba_njit
    def categorical_rv(rng, size, dtype, p):
        if not size_len:
            # No explicit size: one draw per independent distribution in `p`.
            size_tpl = p.shape[:-1]
        else:
            size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
            # Broadcast the independent dimensions of `p` up to `size`; this
            # raises if the shapes cannot be broadcast together.
            p = np.broadcast_to(p, size_tpl + p.shape[-1:])

        unif_samples = np.random.uniform(0, 1, size_tpl)

        res = np.empty(size_tpl, dtype=out_dtype)
        for idx in np.ndindex(*size_tpl):
            # Invert the CDF of the distribution at `idx`.
            res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx])

        return (rng, res)

    # NOTE(review): the `return categorical_rv` sits just past the visible
    # diff hunk; a registered funcify must return the jitted implementation.
    return categorical_rv
......
...@@ -1214,20 +1214,17 @@ class CategoricalRV(RandomVariable): ...@@ -1214,20 +1214,17 @@ class CategoricalRV(RandomVariable):
@classmethod
def rng_fn(cls, rng, p, size):
    """Sample category indices from distributions parameterized by `p`.

    Parameters
    ----------
    rng : numpy.random.Generator
        Generator used for the underlying uniform draws.
    p : ndarray
        Event probabilities; the last axis indexes the categories, any
        leading axes are independent distributions.
    size : tuple of int or None
        Requested sample shape.  ``None`` means one draw per independent
        distribution (shape ``p.shape[:-1]``).  Otherwise ``size`` must be
        broadcast-compatible with ``p.shape[:-1]`` without expanding `p`.

    Returns
    -------
    ndarray
        Sampled category indices with shape ``size`` (or ``p.shape[:-1]``).

    Raises
    ------
    ValueError
        If ``size`` is incompatible with the independent dimensions of `p`.
    """
    if size is None:
        size = p.shape[:-1]
    else:
        # Check that `size` does not define a shape that would be broadcasted
        # to `p.shape[:-1]` in the call to `vsearchsorted` below.
        if len(size) < (p.ndim - 1):
            raise ValueError("`size` is incompatible with the shape of `p`")
        for s, ps in zip(reversed(size), reversed(p.shape[:-1])):
            if s == 1 and ps != 1:
                raise ValueError("`size` is incompatible with the shape of `p`")

    unif_samples = rng.uniform(size=size)
    # `p.cumsum(axis=-1)` yields per-distribution CDFs; searching each
    # uniform draw within its CDF gives the sampled category index.
    samples = vsearchsorted(p.cumsum(axis=-1), unif_samples)

    return samples
......
...@@ -3216,6 +3216,19 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -3216,6 +3216,19 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
(10, 3), (10, 3),
contextlib.suppress(), contextlib.suppress(),
), ),
pytest.param(
[
set_test_value(
at.dmatrix(),
np.array(
[[100000, 1, 1]],
dtype=np.float64,
),
),
],
(5, 4, 3),
contextlib.suppress(),
),
pytest.param( pytest.param(
[ [
set_test_value( set_test_value(
...@@ -3227,7 +3240,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -3227,7 +3240,9 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
), ),
], ],
(10, 4), (10, 4),
pytest.raises(ValueError, match="Parameters shape.*"), pytest.raises(
ValueError, match="objects cannot be broadcast to a single shape"
),
), ),
], ],
) )
......
import pickle import pickle
import re
from copy import copy from copy import copy
import numpy as np import numpy as np
...@@ -1229,10 +1230,15 @@ def test_multinomial_rng(): ...@@ -1229,10 +1230,15 @@ def test_multinomial_rng():
(10, 2, 3), (10, 2, 3),
lambda *args, **kwargs: np.tile(np.arange(3).astype(np.int64), (10, 2, 1)), lambda *args, **kwargs: np.tile(np.arange(3).astype(np.int64), (10, 2, 1)),
), ),
(
np.full((4, 1, 3), [100000, 1, 1], dtype=config.floatX),
(4, 2),
lambda *args, **kwargs: np.zeros((4, 2), dtype=np.int64),
),
], ],
) )
def test_categorical_samples(p, size, test_fn): def test_categorical_samples(p, size, test_fn):
p = p / p.sum(axis=-1) p = p / p.sum(axis=-1, keepdims=True)
rng = np.random.default_rng(232) rng = np.random.default_rng(232)
compare_sample_values( compare_sample_values(
...@@ -1251,7 +1257,20 @@ def test_categorical_basic(): ...@@ -1251,7 +1257,20 @@ def test_categorical_basic():
rng = np.random.default_rng() rng = np.random.default_rng()
with pytest.raises(ValueError): with pytest.raises(ValueError):
categorical.rng_fn(rng, p, size=10) # The independent dimension of p has shape=(3,) which cannot be
# broadcasted to (10,)
categorical.rng_fn(rng, p, size=(10,))
msg = re.escape("`size` is incompatible with the shape of `p`")
with pytest.raises(ValueError, match=msg):
# The independent dimension of p has shape=(3,) which cannot be
# broadcasted to (1,)
categorical.rng_fn(rng, p, size=(1,))
with pytest.raises(ValueError, match=msg):
# The independent dimensions of p have shape=(1, 3) which cannot be
# broadcasted to (3,)
categorical.rng_fn(rng, p[None], size=(3,))
def test_randint_samples(): def test_randint_samples():
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment