提交 d4fe4c3f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove PyPolyaGamma

上级 e95e7736
......@@ -15,14 +15,6 @@ from aesara.tensor.random.var import (
)
try:
from pypolyagamma import PyPolyaGamma
except ImportError: # pragma: no cover
def PyPolyaGamma(*args, **kwargs):
raise RuntimeError("pypolygamma not installed!")
try:
broadcast_shapes = np.broadcast_shapes
except AttributeError:
......@@ -635,47 +627,6 @@ class CategoricalRV(RandomVariable):
categorical = CategoricalRV()
class PolyaGammaRV(RandomVariable):
"""Polya-Gamma random variable.
XXX: This doesn't really use the given RNG, due to the narrowness of the
sampler package's implementation.
"""
name = "polya-gamma"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("PG", "\\operatorname{PG}")
@classmethod
def rng_fn(cls, rng, b, c, size):
rand_method = rng.integers if hasattr(rng, "integers") else rng.randint
pg = PyPolyaGamma(rand_method(2 ** 16))
if not size and b.shape == c.shape == ():
return pg.pgdraw(b, c)
else:
b, c = np.broadcast_arrays(b, c)
size = tuple(size or ())
if len(size) > 0:
b = np.broadcast_to(b, size)
c = np.broadcast_to(c, size)
smpl_val = np.empty(b.shape, dtype="double")
pg.pgdrawv(
np.asarray(b.flat).astype("double", copy=True),
np.asarray(c.flat).astype("double", copy=True),
np.asarray(smpl_val.flat),
)
return smpl_val
polyagamma = PolyaGammaRV()
class RandIntRV(RandomVariable):
name = "randint"
ndim_supp = 0
......
......@@ -14,7 +14,6 @@ dependencies:
- numpy
- scipy
- sympy
- pypolyagamma
# Intel BLAS
- mkl
- mkl-service
......
......@@ -45,7 +45,6 @@ from aesara.tensor.random.basic import (
pareto,
permutation,
poisson,
polyagamma,
randint,
standard_normal,
triangular,
......@@ -1157,39 +1156,6 @@ def test_categorical_basic():
categorical.rng_fn(rng, p, size=10)
@config.change_flags(compute_test_value="raise")
def test_polyagamma_samples():
_ = pytest.importorskip("pypolyagamma")
# Sampled values should be scalars
a = np.array(1.1, dtype=config.floatX)
b = np.array(-10.5, dtype=config.floatX)
pg_rv = polyagamma(a, b)
assert get_test_value(pg_rv).shape == ()
pg_rv = polyagamma(a, b, size=[1])
assert get_test_value(pg_rv).shape == (1,)
pg_rv = polyagamma(a, b, size=[2, 3])
bcast_smpl = get_test_value(pg_rv)
assert bcast_smpl.shape == (2, 3)
# Make sure they're not all equal
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
a = np.array([1.1, 3], dtype=config.floatX)
b = np.array(-10.5, dtype=config.floatX)
pg_rv = polyagamma(a, b)
bcast_smpl = get_test_value(pg_rv)
assert bcast_smpl.shape == (2,)
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
pg_rv = polyagamma(a, b, size=(3, 2))
bcast_smpl = get_test_value(pg_rv)
assert bcast_smpl.shape == (3, 2)
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
def test_randint_samples():
with pytest.raises(TypeError):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论