提交 78706bd8 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Make `ChoiceRV` behave like NumPy's `choice`

`ChoiceRV` currently does not behave like its NumPy equivalent when its `a` parameters is an `int` or an array with more than one dimension. We implement the necessary changes so that it does.
上级 d7fb9402
......@@ -1759,7 +1759,14 @@ class ChoiceRV(RandomVariable):
raise NotImplementedError()
def _infer_shape(self, size, dist_params, param_shapes=None):
return size
(a, p, _) = dist_params
if isinstance(p.type, aesara.tensor.type_other.NoneTypeT):
shape = super()._infer_shape(size, (a,), param_shapes)
else:
shape = super()._infer_shape(size, (a, p), param_shapes)
return shape
def __call__(self, a, size=None, replace=True, p=None, **kwargs):
r"""Generate a random sample from an array.
......@@ -1767,7 +1774,8 @@ class ChoiceRV(RandomVariable):
Parameters
----------
a
The array from which to randomly sample an element.
The array from which to randomly sample an element. If an int,
a sample is generated from `aesara.tensor.arange(a)`.
size
Sample shape. If the given size is `(m, n, k)`, then `m * n *
k` independent samples are returned. Default is `None`, in
......@@ -1778,7 +1786,10 @@ class ChoiceRV(RandomVariable):
The probabilities associated with each entry in `a`. If not
given, all elements have equal probability.
"""
a = as_tensor_variable(a, ndim=1)
a = as_tensor_variable(a)
if a.ndim == 0:
a = aesara.tensor.arange(a)
if p is None:
p = aesara.tensor.type_other.NoneConst.clone()
......
......@@ -1321,17 +1321,28 @@ def test_choice_samples():
with pytest.raises(NotImplementedError):
choice._supp_shape_from_params(np.asarray(5))
compare_sample_values(choice, np.asarray(5))
compare_sample_values(choice, np.asarray([5]))
compare_sample_values(choice, np.array([1.0, 5.0], dtype=config.floatX))
compare_sample_values(choice, np.asarray([5]), 3)
with pytest.raises(ValueError):
compare_sample_values(choice, np.array([[1, 2], [3, 4]]))
compare_sample_values(choice, np.array([[1, 2], [3, 4]]), p=[0.4, 0.6])
compare_sample_values(choice, [1, 2, 3], 1)
compare_sample_values(
choice, [1, 2, 3], 1, p=at.as_tensor([1 / 3.0, 1 / 3.0, 1 / 3.0])
)
# p must be 1-dimensional.
# TODO: The exception is raised at runtime but could be raised at compile
# time in some situations using static shape analysis.
with pytest.raises(ValueError):
rng = np.random.default_rng()
rng_at = shared(rng, borrow=True)
choice(a=[1, 2], p=at.as_tensor([[0.1, 0.9], [0.3, 0.7]]), rng=rng_at).eval()
compare_sample_values(choice, [1, 2, 3], (10, 2), replace=True)
compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论