提交 3169197c authored 作者: Trey Wenger's avatar Trey Wenger 提交者: Ricardo Vieira

Reparameterize GammaRV so beta is not inverted at each call

Also fix wrong JAX implementation of Gamma and Pareto RVs
上级 39aa1234
...@@ -216,21 +216,20 @@ def jax_sample_fn_uniform(op): ...@@ -216,21 +216,20 @@ def jax_sample_fn_uniform(op):
@jax_sample_fn.register(aer.ParetoRV) @jax_sample_fn.register(aer.ParetoRV)
@jax_sample_fn.register(aer.GammaRV) @jax_sample_fn.register(aer.GammaRV)
def jax_sample_fn_shape_rate(op): def jax_sample_fn_shape_scale(op):
"""JAX implementation of random variables in the shape-rate family. """JAX implementation of random variables in the shape-scale family.
JAX only implements the standard version of random variables in the JAX only implements the standard version of random variables in the
shape-rate family. We thus need to rescale the results manually. shape-scale family. We thus need to rescale the results manually.
""" """
name = op.name name = op.name
jax_op = getattr(jax.random, name) jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, shape, scale):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2) rng_key, sampling_key = jax.random.split(rng_key, 2)
(shape, rate) = parameters sample = jax_op(sampling_key, shape, size, dtype) * scale
sample = jax_op(sampling_key, shape, size, dtype) / rate
rng["jax_state"] = rng_key rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
......
import abc import abc
import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
...@@ -419,7 +420,7 @@ class LogNormalRV(RandomVariable): ...@@ -419,7 +420,7 @@ class LogNormalRV(RandomVariable):
lognormal = LogNormalRV() lognormal = LogNormalRV()
class GammaRV(ScipyRandomVariable): class GammaRV(RandomVariable):
r"""A gamma continuous random variable. r"""A gamma continuous random variable.
The probability density function for `gamma` in terms of the shape parameter The probability density function for `gamma` in terms of the shape parameter
...@@ -443,7 +444,7 @@ class GammaRV(ScipyRandomVariable): ...@@ -443,7 +444,7 @@ class GammaRV(ScipyRandomVariable):
dtype = "floatX" dtype = "floatX"
_print_name = ("Gamma", "\\operatorname{Gamma}") _print_name = ("Gamma", "\\operatorname{Gamma}")
def __call__(self, shape, rate, size=None, **kwargs): def __call__(self, shape, scale, size=None, **kwargs):
r"""Draw samples from a gamma distribution. r"""Draw samples from a gamma distribution.
Signature Signature
...@@ -455,8 +456,8 @@ class GammaRV(ScipyRandomVariable): ...@@ -455,8 +456,8 @@ class GammaRV(ScipyRandomVariable):
---------- ----------
shape shape
The shape :math:`\alpha` of the gamma distribution. Must be positive. The shape :math:`\alpha` of the gamma distribution. Must be positive.
rate scale
The rate :math:`\beta` of the gamma distribution. Must be positive. The scale :math:`1/\beta` of the gamma distribution. Must be positive.
size size
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
independent, identically distributed random variables are independent, identically distributed random variables are
...@@ -464,14 +465,26 @@ class GammaRV(ScipyRandomVariable): ...@@ -464,14 +465,26 @@ class GammaRV(ScipyRandomVariable):
is returned. is returned.
""" """
return super().__call__(shape, 1.0 / rate, size=size, **kwargs) return super().__call__(shape, scale, size=size, **kwargs)
@classmethod
def rng_fn_scipy(cls, rng, shape, scale, size):
return stats.gamma.rvs(shape, scale=scale, size=size, random_state=rng)
_gamma = GammaRV()
def gamma(shape, rate=None, scale=None, **kwargs):
    """Draw samples from a gamma distribution.

    Backward-compatible wrapper around the scale-parameterized ``GammaRV``.

    Parameters
    ----------
    shape
        Shape parameter of the gamma distribution.
    rate
        Deprecated inverse-scale parameter; converted to ``scale = 1 / rate``
        with a ``FutureWarning``.
    scale
        Scale parameter of the gamma distribution.

    Raises
    ------
    ValueError
        If both or neither of ``rate`` and ``scale`` are given.
    """
    # TODO: Remove this wrapper once the deprecated ``rate`` argument is gone.
    if rate is not None and scale is not None:
        raise ValueError("Cannot specify both rate and scale")
    if rate is None and scale is None:
        raise ValueError("Must specify scale")
    if rate is not None:
        warnings.warn(
            "Gamma rate argument is deprecated and will stop working, use scale instead",
            FutureWarning,
        )
        scale = 1.0 / rate
    return _gamma(shape, scale, **kwargs)
class ChiSquareRV(RandomVariable): class ChiSquareRV(RandomVariable):
......
...@@ -15,9 +15,9 @@ from pytensor.tensor.random.basic import ( ...@@ -15,9 +15,9 @@ from pytensor.tensor.random.basic import (
LogNormalRV, LogNormalRV,
NegBinomialRV, NegBinomialRV,
WaldRV, WaldRV,
_gamma,
beta, beta,
binomial, binomial,
gamma,
normal, normal,
poisson, poisson,
uniform, uniform,
...@@ -92,7 +92,7 @@ def geometric_from_uniform(fgraph, node): ...@@ -92,7 +92,7 @@ def geometric_from_uniform(fgraph, node):
@node_rewriter([NegBinomialRV])
def negative_binomial_from_gamma_poisson(fgraph, node):
    """Rewrite a negative-binomial draw as a gamma-Poisson mixture.

    A ``NegBinomial(n, p)`` draw is produced by sampling a gamma variate
    with shape ``n`` and scale ``(1 - p) / p`` and using it as the Poisson
    rate.
    """
    rng, *common_args, n, p = node.inputs
    next_rng, mixture_rate = _gamma.make_node(rng, *common_args, n, (1 - p) / p).outputs
    next_rng, nb_draw = poisson.make_node(next_rng, *common_args, mixture_rate).outputs
    return [next_rng, nb_draw]
...@@ -100,21 +100,21 @@ def negative_binomial_from_gamma_poisson(fgraph, node): ...@@ -100,21 +100,21 @@ def negative_binomial_from_gamma_poisson(fgraph, node):
@node_rewriter([InvGammaRV])
def inverse_gamma_from_gamma(fgraph, node):
    """Sample ``InvGamma(shape, scale)`` as ``1 / Gamma(shape, 1 / scale)``."""
    *common_args, shape_param, scale_param = node.inputs
    next_rng, gamma_draw = _gamma.make_node(
        *common_args, shape_param, 1 / scale_param
    ).outputs
    return [next_rng, reciprocal(gamma_draw)]
@node_rewriter([ChiSquareRV])
def chi_square_from_gamma(fgraph, node):
    """Sample ``ChiSquare(df)`` as ``Gamma(df / 2, scale=2)``."""
    *common_args, dof = node.inputs
    next_rng, chisq_draw = _gamma.make_node(*common_args, dof / 2, 2).outputs
    return [next_rng, chisq_draw]
@node_rewriter([GenGammaRV])
def generalized_gamma_from_gamma(fgraph, node):
    """Sample a generalized gamma by transforming a unit-scale gamma draw."""
    *common_args, alpha, p, lambd = node.inputs
    next_rng, base_draw = _gamma.make_node(
        *common_args, alpha / p, ones_like(lambd)
    ).outputs
    # GenGamma(alpha, p, lambd) = lambd * Gamma(alpha / p, scale=1) ** (1 / p)
    transformed = (base_draw ** reciprocal(p)) * lambd
    return [next_rng, cast(transformed, dtype=node.default_output().dtype)]
......
...@@ -4,7 +4,7 @@ import scipy.stats as stats ...@@ -4,7 +4,7 @@ import scipy.stats as stats
import pytensor import pytensor
import pytensor.tensor as at import pytensor.tensor as at
import pytensor.tensor.random as aer import pytensor.tensor.random.basic as aer
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
...@@ -140,15 +140,15 @@ def test_random_updates_input_storage_order(): ...@@ -140,15 +140,15 @@ def test_random_updates_input_storage_order():
lambda *args: (0, args[0]), lambda *args: (0, args[0]),
), ),
( (
aer.gamma, aer._gamma,
[ [
set_test_value( set_test_value(
at.dvector(), at.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( set_test_value(
at.dscalar(), at.dvector(),
np.array(1.0, dtype=np.float64), np.array([0.5, 3.0], dtype=np.float64),
), ),
], ],
(2,), (2,),
...@@ -235,11 +235,15 @@ def test_random_updates_input_storage_order(): ...@@ -235,11 +235,15 @@ def test_random_updates_input_storage_order():
set_test_value( set_test_value(
at.dvector(), at.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
) ),
set_test_value(
at.dvector(),
np.array([2.0, 10.0], dtype=np.float64),
),
], ],
(2,), (2,),
"pareto", "pareto",
lambda *args: args, lambda shape, scale: (shape, 0.0, scale),
), ),
( (
aer.poisson, aer.poisson,
......
...@@ -92,6 +92,10 @@ rng = np.random.default_rng(42849) ...@@ -92,6 +92,10 @@ rng = np.random.default_rng(42849)
at.dvector(), at.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value(
at.dvector(),
np.array([2.0, 10.0], dtype=np.float64),
),
], ],
at.as_tensor([3, 2]), at.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Not implemented"), marks=pytest.mark.xfail(reason="Not implemented"),
...@@ -316,15 +320,15 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -316,15 +320,15 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
lambda *args: args, lambda *args: args,
), ),
( (
aer.gamma, aer._gamma,
[ [
set_test_value( set_test_value(
at.dvector(), at.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( set_test_value(
at.dscalar(), at.dvector(),
np.array(1.0, dtype=np.float64), np.array([0.5, 3.0], dtype=np.float64),
), ),
], ],
(2,), (2,),
......
...@@ -17,6 +17,7 @@ from pytensor.graph.op import get_test_value ...@@ -17,6 +17,7 @@ from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.basic import ( from pytensor.tensor.random.basic import (
_gamma,
bernoulli, bernoulli,
beta, beta,
betabinom, betabinom,
...@@ -351,20 +352,31 @@ def test_lognormal_samples(mean, sigma, size): ...@@ -351,20 +352,31 @@ def test_lognormal_samples(mean, sigma, size):
], ],
) )
def test_gamma_samples(a, b, size):
    """Check GammaRV draws for the given shape ``a``, scale ``b``, and size."""
    compare_sample_values(_gamma, a, b, size=size)
def test_gamma_deprecation_wrapper_fn():
    """The ``gamma`` helper validates its arguments and converts ``rate`` to ``scale``."""
    scale_rv = gamma(5.0, scale=0.5, size=(5,))
    assert scale_rv.type.shape == (5,)
    assert scale_rv.owner.inputs[-1].eval() == 0.5

    # Passing rate positionally must warn and be converted to scale = 1 / rate
    with pytest.warns(FutureWarning, match="Gamma rate argument is deprecated"):
        rate_rv = gamma([5.0, 10.0], 2.0, size=None)
    assert rate_rv.type.shape == (2,)
    assert rate_rv.owner.inputs[-1].eval() == 0.5

    with pytest.raises(ValueError, match="Must specify scale"):
        gamma(5.0)

    with pytest.raises(ValueError, match="Cannot specify both rate and scale"):
        gamma(5.0, rate=2.0, scale=0.5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"df, size", "df, size",
[ [
...@@ -470,18 +482,24 @@ def test_vonmises_samples(mu, kappa, size): ...@@ -470,18 +482,24 @@ def test_vonmises_samples(mu, kappa, size):
@pytest.mark.parametrize(
    "alpha, scale, size",
    [
        (np.array(0.5, dtype=config.floatX), np.array(3.0, dtype=config.floatX), None),
        (np.array(0.5, dtype=config.floatX), np.array(5.0, dtype=config.floatX), []),
        (
            np.full((1, 2), 0.5, dtype=config.floatX),
            np.array([0.5, 1.0], dtype=config.floatX),
            None,
        ),
    ],
)
def test_pareto_samples(alpha, scale, size):
    """Compare Pareto draws against SciPy using the shape/scale parameterization."""
    scipy_pareto = fixed_scipy_rvs("pareto")

    def wrapped_test_fn(shape, scale, **kwargs):
        # Forward the scale keyword explicitly to the SciPy-backed sampler
        return scipy_pareto(shape, scale=scale, **kwargs)

    compare_sample_values(pareto, alpha, scale, size=size, test_fn=wrapped_test_fn)
def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论