提交 c19ac799 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix ScipyRandomVariable issue when size is None and parameters are all broadcastable

上级 1a85652e
...@@ -57,7 +57,7 @@ class ScipyRandomVariable(RandomVariable): ...@@ -57,7 +57,7 @@ class ScipyRandomVariable(RandomVariable):
if np.ndim(res) == 0: if np.ndim(res) == 0:
# The sample is an `np.number`, and is not writeable, or non-NumPy # The sample is an `np.number`, and is not writeable, or non-NumPy
# type, so we need to clone/create a usable NumPy result # type, so we need to clone/create a usable NumPy result
return np.asarray(res) res = np.asarray(res)
if size is None: if size is None:
# SciPy will sometimes drop broadcastable dimensions; we need to # SciPy will sometimes drop broadcastable dimensions; we need to
......
...@@ -787,6 +787,11 @@ def test_hypergeometric_samples(ngood, nbad, nsample, size): ...@@ -787,6 +787,11 @@ def test_hypergeometric_samples(ngood, nbad, nsample, size):
"loc, scale, size", "loc, scale, size",
[ [
(np.array(10, dtype=config.floatX), np.array(0.1, dtype=config.floatX), None), (np.array(10, dtype=config.floatX), np.array(0.1, dtype=config.floatX), None),
(
np.array([[0]], dtype=config.floatX),
np.array([[1]], dtype=config.floatX),
None,
),
(np.array(10, dtype=config.floatX), np.array(0.1, dtype=config.floatX), []), (np.array(10, dtype=config.floatX), np.array(0.1, dtype=config.floatX), []),
(np.array(10, dtype=config.floatX), np.array(0.1, dtype=config.floatX), [2, 3]), (np.array(10, dtype=config.floatX), np.array(0.1, dtype=config.floatX), [2, 3]),
( (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论