提交 a576fa2c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Work-around for numpy bug in choice with size=()

上级 a5cb3b4f
......@@ -2084,6 +2084,10 @@ class ChoiceWithoutReplacement(RandomVariable):
batch_ndim = max(batch_ndim, size_ndim)
if batch_ndim == 0:
# Numpy choice fails with size=() if a.ndim > 1 is batched
# https://github.com/numpy/numpy/issues/26518
if core_shape == ():
core_shape = None
return rng.choice(a, p=p, size=core_shape, replace=False)
# Numpy choice doesn't have a concept of batch dims
......
......@@ -1422,6 +1422,15 @@ def test_choice_samples():
compare_sample_values(choice, pt.as_tensor_variable([1, 2, 3]), 2, replace=True)
def test_choice_scalar_size():
np.testing.assert_array_equal(
choice([[1, 2, 3]], size=(), replace=True).eval(), [1, 2, 3]
)
np.testing.assert_array_equal(
choice([[1, 2, 3]], size=(), replace=False).eval(), [1, 2, 3]
)
def test_permutation_samples():
compare_sample_values(
permutation,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论