提交 bf519076 作者: Ricardo Vieira 提交者: Ricardo Vieira

Extend supported RandomVariables in JAX backend via rewrites

上级 d356950c
...@@ -260,23 +260,6 @@ def jax_sample_fn_t(op): ...@@ -260,23 +260,6 @@ def jax_sample_fn_t(op):
return sample_fn return sample_fn
@jax_sample_fn.register(aer.HalfNormalRV)
def jax_sample_fn_halfnormal(op):
    """Build the JAX sampler for `HalfNormalRV`.

    The returned ``sample_fn(rng, size, dtype, *parameters)`` reads a JAX PRNG
    key from ``rng["jax_state"]``, unpacks ``parameters`` as ``(loc, scale)``
    and returns ``(rng, sample)`` after writing the advanced key back into
    ``rng``.
    """

    def sample_fn(rng, size, dtype, *parameters):
        # Split the stored key: one part becomes the next state, the other is
        # consumed by this draw.
        next_key, draw_key = jax.random.split(rng["jax_state"], 2)
        loc, scale = parameters
        standard_draw = jax.random.normal(draw_key, size, dtype)
        # Take the absolute value of the standard normal draw, then scale
        # and shift it.
        folded = jax.numpy.abs(standard_draw)
        half_normal = loc + folded * scale
        rng["jax_state"] = next_key
        return (rng, half_normal)

    return sample_fn
@jax_sample_fn.register(aer.ChoiceRV) @jax_sample_fn.register(aer.ChoiceRV)
def jax_funcify_choice(op): def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`.""" """JAX implementation of `ChoiceRV`."""
...@@ -305,19 +288,3 @@ def jax_sample_fn_permutation(op): ...@@ -305,19 +288,3 @@ def jax_sample_fn_permutation(op):
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
@jax_sample_fn.register(aer.LogNormalRV)
def jax_sample_fn_lognormal(op):
    """Build the JAX sampler for `LogNormalRV`.

    The returned ``sample_fn(rng, size, dtype, *parameters)`` reads a JAX PRNG
    key from ``rng["jax_state"]``, unpacks ``parameters`` as ``(loc, scale)``
    and returns ``(rng, sample)`` after writing the advanced key back into
    ``rng``.
    """

    def sample_fn(rng, size, dtype, *parameters):
        # Split the stored key: one part becomes the next state, the other is
        # consumed by this draw.
        next_key, draw_key = jax.random.split(rng["jax_state"], 2)
        loc, scale = parameters
        # Draw on the log scale first, then exponentiate.
        log_scale_draw = loc + jax.random.normal(draw_key, size, dtype) * scale
        log_normal = jax.numpy.exp(log_scale_draw)
        rng["jax_state"] = next_key
        return (rng, log_normal)

    return sample_fn
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.tensor.basic import MakeVector from pytensor.graph.rewriting.db import SequenceDB
from pytensor.tensor import abs as abs_t
from pytensor.tensor import exp, floor, log, log1p, reciprocal, sqrt
from pytensor.tensor.basic import MakeVector, cast, ones_like, switch, zeros_like
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
ChiSquareRV,
GenGammaRV,
GeometricRV,
HalfNormalRV,
InvGammaRV,
LogNormalRV,
NegBinomialRV,
WaldRV,
gamma,
normal,
poisson,
uniform,
)
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
...@@ -47,6 +64,118 @@ def size_parameter_as_tuple(fgraph, node): ...@@ -47,6 +64,118 @@ def size_parameter_as_tuple(fgraph, node):
return new_node.outputs return new_node.outputs
@node_rewriter([LogNormalRV])
def lognormal_from_normal(fgraph, node):
    """Rewrite a `LogNormalRV` node as ``exp`` of a `normal` draw.

    The replacement `normal` node receives the original node's inputs
    unchanged.  Returns the new RNG output followed by the exponentiated draw.
    """
    normal_node = normal.make_node(*node.inputs)
    updated_rng, gaussian = normal_node.outputs
    return [updated_rng, exp(gaussian)]
@node_rewriter([HalfNormalRV])
def halfnormal_from_normal(fgraph, node):
    """Rewrite a `HalfNormalRV` node as ``abs(normal) + loc``.

    The last two inputs of the node are taken as ``loc`` and ``scale``.  The
    `normal` draw uses ``zeros_like(loc)`` in place of ``loc``; ``loc`` is
    added back after the absolute value.  The result is cast to the dtype of
    the original node's default output.
    """
    *leading_inputs, loc, scale = node.inputs
    centered_node = normal.make_node(*leading_inputs, zeros_like(loc), scale)
    updated_rng, gaussian = centered_node.outputs
    shifted = abs_t(gaussian) + loc
    out_dtype = node.default_output().dtype
    return [updated_rng, cast(shifted, dtype=out_dtype)]
@node_rewriter([GeometricRV])
def geometric_from_uniform(fgraph, node):
    """Rewrite a `GeometricRV` node in terms of a `uniform` draw.

    The last input of the node is taken as ``p``.  A `uniform` node is built
    with ``zeros_like(p)`` and ``1`` as its two parameters, and the draw ``u``
    is mapped to ``floor(log(u) / log1p(-p)) + 1``.  The result is cast to the
    dtype of the original node's default output.
    """
    *leading_inputs, p = node.inputs
    uniform_node = uniform.make_node(*leading_inputs, zeros_like(p), 1)
    updated_rng, u = uniform_node.outputs
    log_ratio = log(u) / log1p(-p)
    draws = floor(log_ratio) + 1
    out_dtype = node.default_output().dtype
    return [updated_rng, cast(draws, dtype=out_dtype)]
@node_rewriter([NegBinomialRV])
def negative_binomial_from_gamma_poisson(fgraph, node):
    """Rewrite a `NegBinomialRV` node as a `poisson` draw fed by a `gamma` draw.

    The first input is taken as the RNG and the last two as ``n`` and ``p``.
    A `gamma` node is built with ``n`` and ``p / (1 - p)``; its draw is then
    used as the single parameter of a `poisson` node that consumes the RNG
    returned by the `gamma` node.

    NOTE(review): whether `gamma` reads its second parameter as a rate or a
    scale is not visible here -- confirm against the `gamma` definition.
    """
    rng, *middle_inputs, n, p = node.inputs
    gamma_node = gamma.make_node(rng, *middle_inputs, n, p / (1 - p))
    rng_after_gamma, gamma_draw = gamma_node.outputs
    poisson_node = poisson.make_node(rng_after_gamma, *middle_inputs, gamma_draw)
    updated_rng, counts = poisson_node.outputs
    return [updated_rng, counts]
@node_rewriter([InvGammaRV])
def inverse_gamma_from_gamma(fgraph, node):
    """Rewrite an `InvGammaRV` node as the reciprocal of a `gamma` draw.

    The last two inputs of the node are taken as ``shape`` and ``scale`` and
    are passed, in that order, as the two parameters of the `gamma` node.

    NOTE(review): this relies on how `gamma` interprets its second parameter
    (rate vs. scale), which is not visible here -- confirm.
    """
    *leading_inputs, shape, scale = node.inputs
    gamma_node = gamma.make_node(*leading_inputs, shape, scale)
    updated_rng, gamma_draw = gamma_node.outputs
    return [updated_rng, reciprocal(gamma_draw)]
@node_rewriter([ChiSquareRV])
def chi_square_from_gamma(fgraph, node):
    """Rewrite a `ChiSquareRV` node as a `gamma` draw.

    The last input of the node is taken as ``df``.  The `gamma` node is built
    with ``df / 2`` and the constant ``1 / 2`` as its two parameters, and its
    draw is returned as is.

    NOTE(review): this relies on how `gamma` interprets its second parameter
    (rate vs. scale), which is not visible here -- confirm.
    """
    *leading_inputs, df = node.inputs
    half_df = df / 2
    gamma_node = gamma.make_node(*leading_inputs, half_df, 1 / 2)
    updated_rng, gamma_draw = gamma_node.outputs
    return [updated_rng, gamma_draw]
@node_rewriter([GenGammaRV])
def generalized_gamma_from_gamma(fgraph, node):
    """Rewrite a `GenGammaRV` node as a transformed `gamma` draw.

    The last three inputs of the node are taken as ``alpha``, ``p`` and
    ``lambd``.  A `gamma` node is built with ``alpha / p`` and
    ``ones_like(lambd)``; its draw is raised to ``reciprocal(p)`` and then
    multiplied by ``lambd``.  The result is cast to the dtype of the original
    node's default output.
    """
    *leading_inputs, alpha, p, lambd = node.inputs
    gamma_node = gamma.make_node(*leading_inputs, alpha / p, ones_like(lambd))
    updated_rng, gamma_draw = gamma_node.outputs
    powered = gamma_draw ** reciprocal(p)
    rescaled = powered * lambd
    out_dtype = node.default_output().dtype
    return [updated_rng, cast(rescaled, dtype=out_dtype)]
@node_rewriter([WaldRV])
def wald_from_normal_uniform(fgraph, node):
    """Rewrite a `WaldRV` node using one `normal` and one `uniform` draw.

    The first input is taken as the RNG and the last two as ``mean`` and
    ``scale``.  Both auxiliary nodes are built with ``zeros_like(mean)`` and
    ``ones_like(scale)`` as their two parameters; the `uniform` node consumes
    the RNG returned by the `normal` node.

    With ``n`` the normal draw and ``u`` the uniform draw, the graph computes
    ``y = mean * n * n`` and
    ``x = mean + mean / (2 * scale) * (y - sqrt(4 * scale * y + y * y))``,
    then selects ``x`` where ``u <= mean / (mean + x)`` and
    ``mean * mean / x`` elsewhere.  The result is cast to the dtype of the
    original node's default output.
    """
    rng, *middle_inputs, mean, scale = node.inputs

    normal_node = normal.make_node(
        rng, *middle_inputs, zeros_like(mean), ones_like(scale)
    )
    rng_after_normal, n = normal_node.outputs

    uniform_node = uniform.make_node(
        rng_after_normal, *middle_inputs, zeros_like(mean), ones_like(scale)
    )
    updated_rng, u = uniform_node.outputs

    mu_2l = mean / (2 * scale)
    y = mean * n * n
    root = sqrt(4 * scale * y + y * y)
    x = mean + mu_2l * (y - root)

    # Keep the first candidate `x` where the uniform draw falls at or below
    # mean / (mean + x); otherwise use mean * mean / x.
    accept_x = u <= mean / (mean + x)
    draws = switch(accept_x, x, mean * mean / x)

    out_dtype = node.default_output().dtype
    return [updated_rng, cast(draws, dtype=out_dtype)]
# Database holding the RandomVariable rewrites defined above.  Each rewrite
# is wrapped with `in2out` and registered under its own name with the "jax"
# tag.
random_vars_opt = SequenceDB()
# LogNormalRV -> exp(normal)
random_vars_opt.register(
"lognormal_from_normal",
in2out(lognormal_from_normal),
"jax",
)
# HalfNormalRV -> abs(normal) + loc
random_vars_opt.register(
"halfnormal_from_normal",
in2out(halfnormal_from_normal),
"jax",
)
# GeometricRV -> transformed uniform draw
random_vars_opt.register(
"geometric_from_uniform",
in2out(geometric_from_uniform),
"jax",
)
# NegBinomialRV -> poisson draw parametrized by a gamma draw
random_vars_opt.register(
"negative_binomial_from_gamma_poisson",
in2out(negative_binomial_from_gamma_poisson),
"jax",
)
# InvGammaRV -> reciprocal(gamma)
random_vars_opt.register(
"inverse_gamma_from_gamma",
in2out(inverse_gamma_from_gamma),
"jax",
)
# ChiSquareRV -> gamma
random_vars_opt.register(
"chi_square_from_gamma",
in2out(chi_square_from_gamma),
"jax",
)
# GenGammaRV -> transformed gamma draw
random_vars_opt.register(
"generalized_gamma_from_gamma",
in2out(generalized_gamma_from_gamma),
"jax",
)
# WaldRV -> combination of a normal and a uniform draw
random_vars_opt.register(
"wald_from_normal_uniform",
in2out(wald_from_normal_uniform),
"jax",
)
# Register the whole database in the global `optdb` under the "jax" tag.
# NOTE(review): position=110 presumably orders it after
# "jax_size_parameter_as_tuple" (position=100) -- confirm.
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
optdb.register( optdb.register(
"jax_size_parameter_as_tuple", in2out(size_parameter_as_tuple), "jax", position=100 "jax_size_parameter_as_tuple", in2out(size_parameter_as_tuple), "jax", position=100
) )
...@@ -179,7 +179,7 @@ def test_random_updates(rng_ctor): ...@@ -179,7 +179,7 @@ def test_random_updates(rng_ctor):
], ],
(2,), (2,),
"lognorm", "lognorm",
lambda *args: args, lambda mu, sigma: (sigma, 0, np.exp(mu)),
), ),
( (
aer.normal, aer.normal,
...@@ -285,7 +285,7 @@ def test_random_updates(rng_ctor): ...@@ -285,7 +285,7 @@ def test_random_updates(rng_ctor):
[ [
set_test_value( set_test_value(
at.dvector(), at.dvector(),
np.array([-1.0, 2.0], dtype=np.float64), np.array([-1.0, 200.0], dtype=np.float64),
), ),
set_test_value( set_test_value(
at.dscalar(), at.dscalar(),
...@@ -296,6 +296,71 @@ def test_random_updates(rng_ctor): ...@@ -296,6 +296,71 @@ def test_random_updates(rng_ctor):
"halfnorm", "halfnorm",
lambda *args: args, lambda *args: args,
), ),
(
aer.invgamma,
[
set_test_value(
at.dvector(),
np.array([10.4, 2.8], dtype=np.float64),
),
set_test_value(
at.dvector(),
np.array([3.4, 7.3], dtype=np.float64),
),
],
(2,),
"invgamma",
lambda a, b: (a, 0, b),
),
(
aer.chisquare,
[
set_test_value(
at.dvector(),
np.array([2.4, 4.9], dtype=np.float64),
),
],
(2,),
"chi2",
lambda *args: args,
),
(
aer.gengamma,
[
set_test_value(
at.dvector(),
np.array([10.4, 2.8], dtype=np.float64),
),
set_test_value(
at.dvector(),
np.array([3.4, 7.3], dtype=np.float64),
),
set_test_value(
at.dvector(),
np.array([0.9, 2.0], dtype=np.float64),
),
],
(2,),
"gengamma",
lambda alpha, p, lambd: (alpha / p, p, 0, lambd),
),
(
aer.wald,
[
set_test_value(
at.dvector(),
np.array([10.4, 2.8], dtype=np.float64),
),
set_test_value(
at.dvector(),
np.array([4.5, 2.0], dtype=np.float64),
),
],
(2,),
"invgauss",
# https://stackoverflow.com/a/48603469
lambda mean, scale: (mean / scale, 0, scale),
),
], ],
) )
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv): def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
...@@ -329,7 +394,8 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -329,7 +394,8 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
test_res = stats.cramervonmises( test_res = stats.cramervonmises(
samples[(Ellipsis,) + idx], cdf_name, args=cdf_params samples[(Ellipsis,) + idx], cdf_name, args=cdf_params
) )
assert test_res.pvalue > 0.1 assert not np.isnan(test_res.statistic)
assert test_res.pvalue > 0.01
@pytest.mark.parametrize("size", [(), (4,)]) @pytest.mark.parametrize("size", [(), (4,)])
...@@ -410,6 +476,29 @@ def test_random_permutation(): ...@@ -410,6 +476,29 @@ def test_random_permutation():
np.testing.assert_allclose(array, permuted) np.testing.assert_allclose(array, permuted)
def test_random_geometric():
    """Compile a geometric draw with `jax_mode` and check its mean and std.

    Draws 10,000 samples for each of two ``p`` values and compares the sample
    mean with ``1 / p`` and the sample standard deviation with
    ``sqrt((1 - p) / p**2)``, both within a relative tolerance of 0.1.
    """
    seeded_rng = shared(np.random.RandomState(123))
    p = np.array([0.3, 0.7])
    geometric_var = at.random.geometric(p, size=(10_000, 2), rng=seeded_rng)
    draw = function([], geometric_var, mode=jax_mode)
    samples = draw()

    expected_mean = 1 / p
    expected_std = np.sqrt((1 - p) / p**2)
    np.testing.assert_allclose(samples.mean(axis=0), expected_mean, rtol=0.1)
    np.testing.assert_allclose(samples.std(axis=0), expected_std, rtol=0.1)
def test_negative_binomial():
    """Compile a negative binomial draw with `jax_mode` and check mean and std.

    Draws 10,000 samples for each of two ``(n, p)`` pairs and compares the
    sample mean with ``n * (1 - p) / p`` and the sample standard deviation
    with ``sqrt(n * (1 - p) / p**2)``, both within a relative tolerance of 0.1.
    """
    seeded_rng = shared(np.random.RandomState(123))
    n = np.array([10, 40])
    p = np.array([0.3, 0.7])
    neg_binomial_var = at.random.negative_binomial(
        n, p, size=(10_000, 2), rng=seeded_rng
    )
    draw = function([], neg_binomial_var, mode=jax_mode)
    samples = draw()

    expected_mean = n * (1 - p) / p
    expected_std = np.sqrt(n * (1 - p) / p**2)
    np.testing.assert_allclose(samples.mean(axis=0), expected_mean, rtol=0.1)
    np.testing.assert_allclose(samples.std(axis=0), expected_std, rtol=0.1)
def test_random_unimplemented(): def test_random_unimplemented():
"""Compiling a graph with a non-supported `RandomVariable` should """Compiling a graph with a non-supported `RandomVariable` should
raise an error. raise an error.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论