提交 c67741c0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make ScipyRandomVariable results writeable

上级 fc985340
...@@ -25,15 +25,15 @@ except AttributeError: ...@@ -25,15 +25,15 @@ except AttributeError:
class ScipyRandomVariable(RandomVariable): class ScipyRandomVariable(RandomVariable):
r"""A class for `RandomVariable`\s that use SciPy-based samplers. r"""A class for straightforward `RandomVariable`\s that use SciPy-based samplers.
This will only work for `RandomVariable`\s for which the output shape is By "straightforward" we mean `RandomVariable`\s for which the output shape
entirely determined by broadcasting the distribution parameters (e.g. basic is entirely determined by broadcasting the distribution parameters
scalar distributions). (e.g. basic scalar distributions).
The more sophisticated shape logic performed by `RandomVariable` is avoided The more sophisticated shape logic performed by `RandomVariable` is avoided
in order to reduce the amount of unnecessary extra steps taken to correct in order to reduce the amount of unnecessary steps taken to correct SciPy's
SciPy's shape-removing defect. shape-reducing defects.
""" """
...@@ -53,12 +53,20 @@ class ScipyRandomVariable(RandomVariable): ...@@ -53,12 +53,20 @@ class ScipyRandomVariable(RandomVariable):
def rng_fn(cls, *args, **kwargs): def rng_fn(cls, *args, **kwargs):
size = args[-1] size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs) res = cls.rng_fn_scipy(*args, **kwargs)
return np.broadcast_to(
res, if np.ndim(res) == 0:
size # The sample is an `np.number`, and is not writeable, or non-NumPy
if size is not None # type, so we need to clone/create a usable NumPy result
else broadcast_shapes(*[np.shape(a) for a in args[1:-1]]), return np.asarray(res)
)
if size is None:
# SciPy will sometimes drop broadcastable dimensions; we need to
# check and, if necessary, add them back
exp_shape = broadcast_shapes(*[np.shape(a) for a in args[1:-1]])
if res.shape != exp_shape:
return np.broadcast_to(res, exp_shape).copy()
return res
class UniformRV(RandomVariable): class UniformRV(RandomVariable):
......
...@@ -124,6 +124,8 @@ def compare_sample_values(rv, *params, rng=None, test_fn=None, **kwargs): ...@@ -124,6 +124,8 @@ def compare_sample_values(rv, *params, rng=None, test_fn=None, **kwargs):
aesara_res_val = aesara_fn() aesara_res_val = aesara_fn()
assert aesara_res_val.flags.writeable
np.testing.assert_array_equal(aesara_res_val.shape, numpy_res.shape) np.testing.assert_array_equal(aesara_res_val.shape, numpy_res.shape)
np.testing.assert_allclose(aesara_res_val, numpy_res) np.testing.assert_allclose(aesara_res_val, numpy_res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论