提交 87bc36c7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Work around squeeze bug in SciPy samplers

上级 2b78c67f
import abc
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
...@@ -22,6 +23,53 @@ except ImportError: # pragma: no cover ...@@ -22,6 +23,53 @@ except ImportError: # pragma: no cover
raise RuntimeError("pypolygamma not installed!") raise RuntimeError("pypolygamma not installed!")
try:
broadcast_shapes = np.broadcast_shapes
except AttributeError:
from numpy.lib.stride_tricks import _broadcast_shape
def broadcast_shapes(*shapes):
arrays = [np.empty(x, dtype=[]) for x in shapes]
return _broadcast_shape(arrays)
class ScipyRandomVariable(RandomVariable):
r"""A class for `RandomVariable`\s that use SciPy-based samplers.
This will only work for `RandomVariable`\s for which the output shape is
entirely determined by broadcasting the distribution parameters (e.g. basic
scalar distributions).
The more sophisticated shape logic performed by `RandomVariable` is avoided
in order to reduce the amount of unnecessary extra steps taken to correct
SciPy's shape-removing defect.
"""
@classmethod
@abc.abstractmethod
def rng_fn_scipy(cls, rng, *args, **kwargs):
r"""
`RandomVariable`\s implementations that want to use SciPy-based samplers
need to implement this method instead of the base
`RandomVariable.rng_fn`; otherwise their broadcast dimensions will be
dropped by SciPy.
"""
@classmethod
def rng_fn(cls, *args, **kwargs):
size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs)
return np.broadcast_to(
res,
size
if size is not None
else broadcast_shapes(*[np.shape(a) for a in args[1:-1]]),
)
class UniformRV(RandomVariable): class UniformRV(RandomVariable):
name = "uniform" name = "uniform"
ndim_supp = 0 ndim_supp = 0
...@@ -72,7 +120,7 @@ class NormalRV(RandomVariable): ...@@ -72,7 +120,7 @@ class NormalRV(RandomVariable):
normal = NormalRV() normal = NormalRV()
class HalfNormalRV(RandomVariable): class HalfNormalRV(ScipyRandomVariable):
name = "halfnormal" name = "halfnormal"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -83,7 +131,7 @@ class HalfNormalRV(RandomVariable): ...@@ -83,7 +131,7 @@ class HalfNormalRV(RandomVariable):
return super().__call__(loc, scale, size=size, **kwargs) return super().__call__(loc, scale, size=size, **kwargs)
@classmethod @classmethod
def rng_fn(cls, rng, loc, scale, size): def rng_fn_scipy(cls, rng, loc, scale, size):
return stats.halfnorm.rvs(loc, scale, random_state=rng, size=size) return stats.halfnorm.rvs(loc, scale, random_state=rng, size=size)
...@@ -104,7 +152,7 @@ class LogNormalRV(RandomVariable): ...@@ -104,7 +152,7 @@ class LogNormalRV(RandomVariable):
lognormal = LogNormalRV() lognormal = LogNormalRV()
class GammaRV(RandomVariable): class GammaRV(ScipyRandomVariable):
name = "gamma" name = "gamma"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -115,7 +163,7 @@ class GammaRV(RandomVariable): ...@@ -115,7 +163,7 @@ class GammaRV(RandomVariable):
return super().__call__(shape, 1.0 / rate, size=size, **kwargs) return super().__call__(shape, 1.0 / rate, size=size, **kwargs)
@classmethod @classmethod
def rng_fn(cls, rng, shape, scale, size): def rng_fn_scipy(cls, rng, shape, scale, size):
return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng) return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
...@@ -133,7 +181,7 @@ class ChiSquareRV(RandomVariable): ...@@ -133,7 +181,7 @@ class ChiSquareRV(RandomVariable):
chisquare = ChiSquareRV() chisquare = ChiSquareRV()
class ParetoRV(RandomVariable): class ParetoRV(ScipyRandomVariable):
name = "pareto" name = "pareto"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -144,14 +192,14 @@ class ParetoRV(RandomVariable): ...@@ -144,14 +192,14 @@ class ParetoRV(RandomVariable):
return super().__call__(b, scale, size=size, **kwargs) return super().__call__(b, scale, size=size, **kwargs)
@classmethod @classmethod
def rng_fn(cls, rng, b, scale, size): def rng_fn_scipy(cls, rng, b, scale, size):
return stats.pareto.rvs(b, scale=scale, size=size, random_state=rng) return stats.pareto.rvs(b, scale=scale, size=size, random_state=rng)
pareto = ParetoRV() pareto = ParetoRV()
class GumbelRV(RandomVariable): class GumbelRV(ScipyRandomVariable):
name = "gumbel" name = "gumbel"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -163,12 +211,12 @@ class GumbelRV(RandomVariable): ...@@ -163,12 +211,12 @@ class GumbelRV(RandomVariable):
loc: Union[np.ndarray, float], loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float] = 1.0, scale: Union[np.ndarray, float] = 1.0,
size: Optional[Union[List[int], int]] = None, size: Optional[Union[List[int], int]] = None,
**kwargs **kwargs,
) -> RandomVariable: ) -> RandomVariable:
return super().__call__(loc, scale, size=size, **kwargs) return super().__call__(loc, scale, size=size, **kwargs)
@classmethod @classmethod
def rng_fn( def rng_fn_scipy(
cls, cls,
rng: Union[np.random.Generator, np.random.RandomState], rng: Union[np.random.Generator, np.random.RandomState],
loc: Union[np.ndarray, float], loc: Union[np.ndarray, float],
...@@ -356,7 +404,7 @@ class HyperGeometricRV(RandomVariable): ...@@ -356,7 +404,7 @@ class HyperGeometricRV(RandomVariable):
hypergeometric = HyperGeometricRV() hypergeometric = HyperGeometricRV()
class CauchyRV(RandomVariable): class CauchyRV(ScipyRandomVariable):
name = "cauchy" name = "cauchy"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -367,14 +415,14 @@ class CauchyRV(RandomVariable): ...@@ -367,14 +415,14 @@ class CauchyRV(RandomVariable):
return super().__call__(loc, scale, size=size, **kwargs) return super().__call__(loc, scale, size=size, **kwargs)
@classmethod @classmethod
def rng_fn(cls, rng, loc, scale, size): def rng_fn_scipy(cls, rng, loc, scale, size):
return stats.cauchy.rvs(loc=loc, scale=scale, random_state=rng, size=size) return stats.cauchy.rvs(loc=loc, scale=scale, random_state=rng, size=size)
cauchy = CauchyRV() cauchy = CauchyRV()
class HalfCauchyRV(RandomVariable): class HalfCauchyRV(ScipyRandomVariable):
name = "halfcauchy" name = "halfcauchy"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -385,14 +433,14 @@ class HalfCauchyRV(RandomVariable): ...@@ -385,14 +433,14 @@ class HalfCauchyRV(RandomVariable):
return super().__call__(loc, scale, size=size, **kwargs) return super().__call__(loc, scale, size=size, **kwargs)
@classmethod @classmethod
def rng_fn(cls, rng, loc, scale, size): def rng_fn_scipy(cls, rng, loc, scale, size):
return stats.halfcauchy.rvs(loc=loc, scale=scale, random_state=rng, size=size) return stats.halfcauchy.rvs(loc=loc, scale=scale, random_state=rng, size=size)
halfcauchy = HalfCauchyRV() halfcauchy = HalfCauchyRV()
class InvGammaRV(RandomVariable): class InvGammaRV(ScipyRandomVariable):
name = "invgamma" name = "invgamma"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -400,7 +448,7 @@ class InvGammaRV(RandomVariable): ...@@ -400,7 +448,7 @@ class InvGammaRV(RandomVariable):
_print_name = ("InvGamma", "\\operatorname{Gamma^{-1}}") _print_name = ("InvGamma", "\\operatorname{Gamma^{-1}}")
@classmethod @classmethod
def rng_fn(cls, rng, shape, rate, size=None): def rng_fn_scipy(cls, rng, shape, rate, size):
return stats.invgamma.rvs(shape, scale=rate, size=size, random_state=rng) return stats.invgamma.rvs(shape, scale=rate, size=size, random_state=rng)
...@@ -421,7 +469,7 @@ class WaldRV(RandomVariable): ...@@ -421,7 +469,7 @@ class WaldRV(RandomVariable):
wald = WaldRV() wald = WaldRV()
class TruncExponentialRV(RandomVariable): class TruncExponentialRV(ScipyRandomVariable):
name = "truncexpon" name = "truncexpon"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0, 0] ndims_params = [0, 0, 0]
...@@ -429,7 +477,7 @@ class TruncExponentialRV(RandomVariable): ...@@ -429,7 +477,7 @@ class TruncExponentialRV(RandomVariable):
_print_name = ("TruncExp", "\\operatorname{TruncExp}") _print_name = ("TruncExp", "\\operatorname{TruncExp}")
@classmethod @classmethod
def rng_fn(cls, rng, b, loc, scale, size=None): def rng_fn_scipy(cls, rng, b, loc, scale, size):
return stats.truncexpon.rvs( return stats.truncexpon.rvs(
b, loc=loc, scale=scale, size=size, random_state=rng b, loc=loc, scale=scale, size=size, random_state=rng
) )
...@@ -438,7 +486,7 @@ class TruncExponentialRV(RandomVariable): ...@@ -438,7 +486,7 @@ class TruncExponentialRV(RandomVariable):
truncexpon = TruncExponentialRV() truncexpon = TruncExponentialRV()
class BernoulliRV(RandomVariable): class BernoulliRV(ScipyRandomVariable):
name = "bernoulli" name = "bernoulli"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0] ndims_params = [0]
...@@ -446,7 +494,7 @@ class BernoulliRV(RandomVariable): ...@@ -446,7 +494,7 @@ class BernoulliRV(RandomVariable):
_print_name = ("Bern", "\\operatorname{Bern}") _print_name = ("Bern", "\\operatorname{Bern}")
@classmethod @classmethod
def rng_fn(cls, rng, p, size=None): def rng_fn_scipy(cls, rng, p, size):
return stats.bernoulli.rvs(p, size=size, random_state=rng) return stats.bernoulli.rvs(p, size=size, random_state=rng)
...@@ -475,7 +523,7 @@ class BinomialRV(RandomVariable): ...@@ -475,7 +523,7 @@ class BinomialRV(RandomVariable):
binomial = BinomialRV() binomial = BinomialRV()
class NegBinomialRV(RandomVariable): class NegBinomialRV(ScipyRandomVariable):
name = "nbinom" name = "nbinom"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0] ndims_params = [0, 0]
...@@ -483,14 +531,14 @@ class NegBinomialRV(RandomVariable): ...@@ -483,14 +531,14 @@ class NegBinomialRV(RandomVariable):
_print_name = ("NB", "\\operatorname{NB}") _print_name = ("NB", "\\operatorname{NB}")
@classmethod @classmethod
def rng_fn(cls, rng, n, p, size=None): def rng_fn_scipy(cls, rng, n, p, size):
return stats.nbinom.rvs(n, p, size=size, random_state=rng) return stats.nbinom.rvs(n, p, size=size, random_state=rng)
nbinom = NegBinomialRV() nbinom = NegBinomialRV()
class BetaBinomialRV(RandomVariable): class BetaBinomialRV(ScipyRandomVariable):
name = "beta_binomial" name = "beta_binomial"
ndim_supp = 0 ndim_supp = 0
ndims_params = [0, 0, 0] ndims_params = [0, 0, 0]
...@@ -498,7 +546,7 @@ class BetaBinomialRV(RandomVariable): ...@@ -498,7 +546,7 @@ class BetaBinomialRV(RandomVariable):
_print_name = ("BetaBinom", "\\operatorname{BetaBinom}") _print_name = ("BetaBinom", "\\operatorname{BetaBinom}")
@classmethod @classmethod
def rng_fn(cls, rng, n, a, b, size=None): def rng_fn_scipy(cls, rng, n, a, b, size):
return stats.betabinom.rvs(n, a, b, size=size, random_state=rng) return stats.betabinom.rvs(n, a, b, size=size, random_state=rng)
...@@ -535,7 +583,7 @@ class MultinomialRV(RandomVariable): ...@@ -535,7 +583,7 @@ class MultinomialRV(RandomVariable):
n = np.broadcast_to(n, size + n.shape) n = np.broadcast_to(n, size + n.shape)
p = np.broadcast_to(p, size + p.shape) p = np.broadcast_to(p, size + p.shape)
res = np.empty(p.shape) res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]): for idx in np.ndindex(p.shape[:-1]):
res[idx] = rng.multinomial(n[idx], p[idx]) res[idx] = rng.multinomial(n[idx], p[idx])
return res return res
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论