提交 87bc36c7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Work around squeeze bug in SciPy samplers

上级 2b78c67f
import abc
from typing import List, Optional, Union
import numpy as np
......@@ -22,6 +23,53 @@ except ImportError: # pragma: no cover
raise RuntimeError("pypolygamma not installed!")
try:
broadcast_shapes = np.broadcast_shapes
except AttributeError:
from numpy.lib.stride_tricks import _broadcast_shape
def broadcast_shapes(*shapes):
arrays = [np.empty(x, dtype=[]) for x in shapes]
return _broadcast_shape(arrays)
class ScipyRandomVariable(RandomVariable):
r"""A class for `RandomVariable`\s that use SciPy-based samplers.
This will only work for `RandomVariable`\s for which the output shape is
entirely determined by broadcasting the distribution parameters (e.g. basic
scalar distributions).
The more sophisticated shape logic performed by `RandomVariable` is avoided
in order to reduce the amount of unnecessary extra steps taken to correct
SciPy's shape-removing defect.
"""
@classmethod
@abc.abstractmethod
def rng_fn_scipy(cls, rng, *args, **kwargs):
r"""
`RandomVariable`\s implementations that want to use SciPy-based samplers
need to implement this method instead of the base
`RandomVariable.rng_fn`; otherwise their broadcast dimensions will be
dropped by SciPy.
"""
@classmethod
def rng_fn(cls, *args, **kwargs):
size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs)
return np.broadcast_to(
res,
size
if size is not None
else broadcast_shapes(*[np.shape(a) for a in args[1:-1]]),
)
class UniformRV(RandomVariable):
name = "uniform"
ndim_supp = 0
......@@ -72,7 +120,7 @@ class NormalRV(RandomVariable):
normal = NormalRV()
class HalfNormalRV(RandomVariable):
class HalfNormalRV(ScipyRandomVariable):
name = "halfnormal"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -83,7 +131,7 @@ class HalfNormalRV(RandomVariable):
return super().__call__(loc, scale, size=size, **kwargs)
@classmethod
def rng_fn(cls, rng, loc, scale, size):
def rng_fn_scipy(cls, rng, loc, scale, size):
return stats.halfnorm.rvs(loc, scale, random_state=rng, size=size)
......@@ -104,7 +152,7 @@ class LogNormalRV(RandomVariable):
lognormal = LogNormalRV()
class GammaRV(RandomVariable):
class GammaRV(ScipyRandomVariable):
name = "gamma"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -115,7 +163,7 @@ class GammaRV(RandomVariable):
return super().__call__(shape, 1.0 / rate, size=size, **kwargs)
@classmethod
def rng_fn(cls, rng, shape, scale, size):
def rng_fn_scipy(cls, rng, shape, scale, size):
return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
......@@ -133,7 +181,7 @@ class ChiSquareRV(RandomVariable):
chisquare = ChiSquareRV()
class ParetoRV(RandomVariable):
class ParetoRV(ScipyRandomVariable):
name = "pareto"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -144,14 +192,14 @@ class ParetoRV(RandomVariable):
return super().__call__(b, scale, size=size, **kwargs)
@classmethod
def rng_fn(cls, rng, b, scale, size):
def rng_fn_scipy(cls, rng, b, scale, size):
return stats.pareto.rvs(b, scale=scale, size=size, random_state=rng)
pareto = ParetoRV()
class GumbelRV(RandomVariable):
class GumbelRV(ScipyRandomVariable):
name = "gumbel"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -163,12 +211,12 @@ class GumbelRV(RandomVariable):
loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float] = 1.0,
size: Optional[Union[List[int], int]] = None,
**kwargs
**kwargs,
) -> RandomVariable:
return super().__call__(loc, scale, size=size, **kwargs)
@classmethod
def rng_fn(
def rng_fn_scipy(
cls,
rng: Union[np.random.Generator, np.random.RandomState],
loc: Union[np.ndarray, float],
......@@ -356,7 +404,7 @@ class HyperGeometricRV(RandomVariable):
hypergeometric = HyperGeometricRV()
class CauchyRV(RandomVariable):
class CauchyRV(ScipyRandomVariable):
name = "cauchy"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -367,14 +415,14 @@ class CauchyRV(RandomVariable):
return super().__call__(loc, scale, size=size, **kwargs)
@classmethod
def rng_fn(cls, rng, loc, scale, size):
def rng_fn_scipy(cls, rng, loc, scale, size):
return stats.cauchy.rvs(loc=loc, scale=scale, random_state=rng, size=size)
cauchy = CauchyRV()
class HalfCauchyRV(RandomVariable):
class HalfCauchyRV(ScipyRandomVariable):
name = "halfcauchy"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -385,14 +433,14 @@ class HalfCauchyRV(RandomVariable):
return super().__call__(loc, scale, size=size, **kwargs)
@classmethod
def rng_fn(cls, rng, loc, scale, size):
def rng_fn_scipy(cls, rng, loc, scale, size):
return stats.halfcauchy.rvs(loc=loc, scale=scale, random_state=rng, size=size)
halfcauchy = HalfCauchyRV()
class InvGammaRV(RandomVariable):
class InvGammaRV(ScipyRandomVariable):
name = "invgamma"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -400,7 +448,7 @@ class InvGammaRV(RandomVariable):
_print_name = ("InvGamma", "\\operatorname{Gamma^{-1}}")
@classmethod
def rng_fn(cls, rng, shape, rate, size=None):
def rng_fn_scipy(cls, rng, shape, rate, size):
return stats.invgamma.rvs(shape, scale=rate, size=size, random_state=rng)
......@@ -421,7 +469,7 @@ class WaldRV(RandomVariable):
wald = WaldRV()
class TruncExponentialRV(RandomVariable):
class TruncExponentialRV(ScipyRandomVariable):
name = "truncexpon"
ndim_supp = 0
ndims_params = [0, 0, 0]
......@@ -429,7 +477,7 @@ class TruncExponentialRV(RandomVariable):
_print_name = ("TruncExp", "\\operatorname{TruncExp}")
@classmethod
def rng_fn(cls, rng, b, loc, scale, size=None):
def rng_fn_scipy(cls, rng, b, loc, scale, size):
return stats.truncexpon.rvs(
b, loc=loc, scale=scale, size=size, random_state=rng
)
......@@ -438,7 +486,7 @@ class TruncExponentialRV(RandomVariable):
truncexpon = TruncExponentialRV()
class BernoulliRV(RandomVariable):
class BernoulliRV(ScipyRandomVariable):
name = "bernoulli"
ndim_supp = 0
ndims_params = [0]
......@@ -446,7 +494,7 @@ class BernoulliRV(RandomVariable):
_print_name = ("Bern", "\\operatorname{Bern}")
@classmethod
def rng_fn(cls, rng, p, size=None):
def rng_fn_scipy(cls, rng, p, size):
return stats.bernoulli.rvs(p, size=size, random_state=rng)
......@@ -475,7 +523,7 @@ class BinomialRV(RandomVariable):
binomial = BinomialRV()
class NegBinomialRV(RandomVariable):
class NegBinomialRV(ScipyRandomVariable):
name = "nbinom"
ndim_supp = 0
ndims_params = [0, 0]
......@@ -483,14 +531,14 @@ class NegBinomialRV(RandomVariable):
_print_name = ("NB", "\\operatorname{NB}")
@classmethod
def rng_fn(cls, rng, n, p, size=None):
def rng_fn_scipy(cls, rng, n, p, size):
return stats.nbinom.rvs(n, p, size=size, random_state=rng)
nbinom = NegBinomialRV()
class BetaBinomialRV(RandomVariable):
class BetaBinomialRV(ScipyRandomVariable):
name = "beta_binomial"
ndim_supp = 0
ndims_params = [0, 0, 0]
......@@ -498,7 +546,7 @@ class BetaBinomialRV(RandomVariable):
_print_name = ("BetaBinom", "\\operatorname{BetaBinom}")
@classmethod
def rng_fn(cls, rng, n, a, b, size=None):
def rng_fn_scipy(cls, rng, n, a, b, size):
return stats.betabinom.rvs(n, a, b, size=size, random_state=rng)
......@@ -535,7 +583,7 @@ class MultinomialRV(RandomVariable):
n = np.broadcast_to(n, size + n.shape)
p = np.broadcast_to(p, size + p.shape)
res = np.empty(p.shape)
res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]):
res[idx] = rng.multinomial(n[idx], p[idx])
return res
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论