提交 c67741c0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make ScipyRandomVariable results writeable

上级 fc985340
......@@ -25,15 +25,15 @@ except AttributeError:
class ScipyRandomVariable(RandomVariable):
r"""A class for `RandomVariable`\s that use SciPy-based samplers.
r"""A class for straightforward `RandomVariable`\s that use SciPy-based samplers.
This will only work for `RandomVariable`\s for which the output shape is
entirely determined by broadcasting the distribution parameters (e.g. basic
scalar distributions).
By "straightforward" we mean `RandomVariable`\s for which the output shape
is entirely determined by broadcasting the distribution parameters
(e.g. basic scalar distributions).
The more sophisticated shape logic performed by `RandomVariable` is avoided
in order to reduce the amount of unnecessary extra steps taken to correct
SciPy's shape-removing defect.
in order to reduce the amount of unnecessary steps taken to correct SciPy's
shape-reducing defects.
"""
......@@ -53,12 +53,20 @@ class ScipyRandomVariable(RandomVariable):
def rng_fn(cls, *args, **kwargs):
size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs)
return np.broadcast_to(
res,
size
if size is not None
else broadcast_shapes(*[np.shape(a) for a in args[1:-1]]),
)
if np.ndim(res) == 0:
# The sample is an `np.number`, and is not writeable, or non-NumPy
# type, so we need to clone/create a usable NumPy result
return np.asarray(res)
if size is None:
# SciPy will sometimes drop broadcastable dimensions; we need to
# check and, if necessary, add them back
exp_shape = broadcast_shapes(*[np.shape(a) for a in args[1:-1]])
if res.shape != exp_shape:
return np.broadcast_to(res, exp_shape).copy()
return res
class UniformRV(RandomVariable):
......
......@@ -124,6 +124,8 @@ def compare_sample_values(rv, *params, rng=None, test_fn=None, **kwargs):
aesara_res_val = aesara_fn()
assert aesara_res_val.flags.writeable
np.testing.assert_array_equal(aesara_res_val.shape, numpy_res.shape)
np.testing.assert_allclose(aesara_res_val, numpy_res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论