Unverified commit f951743d authored by Adriano M. Yoshino, committed by GitHub

Implement betaincinv and gammainc[c]inv functions (#502)

Parent e9694031
@@ -21,11 +21,14 @@ from pytensor.scalar.basic import (
Sub,
)
from pytensor.scalar.math import (
BetaIncInv,
Erf,
Erfc,
Erfcinv,
Erfcx,
Erfinv,
GammaIncCInv,
GammaIncInv,
Iv,
Ive,
Log1mexp,
@@ -226,6 +229,20 @@ def jax_funcify_Second(op, **kwargs):
return second
@jax_funcify.register(GammaIncInv)
def jax_funcify_GammaIncInv(op, **kwargs):
gammaincinv = try_import_tfp_jax_op(op, jax_op_name="igammainv")
return gammaincinv
@jax_funcify.register(GammaIncCInv)
def jax_funcify_GammaIncCInv(op, **kwargs):
gammainccinv = try_import_tfp_jax_op(op, jax_op_name="igammacinv")
return gammainccinv
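
At the time of this PR, `jax.scipy.special` does not ship these inverses, so both dispatches resolve through TensorFlow Probability's JAX substrate via the `try_import_tfp_jax_op` helper. A minimal sketch of the behavior the resolved ops are assumed to have (requires `tensorflow-probability` with its JAX substrate installed):

# Sketch only: what the TFP-backed ops above are assumed to compute.
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

k = jnp.array([5.5, 7.0])
x = jnp.array([0.25, 0.7])

y = tfp.math.igammainv(k, x)   # inverts gammainc: gammainc(k, y) == x
z = tfp.math.igammacinv(k, x)  # inverts gammaincc: gammaincc(k, z) == x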
@jax_funcify.register(Erf)
def jax_funcify_Erf(op, node, **kwargs):
def erf(x):
@@ -250,6 +267,7 @@ def jax_funcify_Erfinv(op, **kwargs):
return erfinv
@jax_funcify.register(BetaIncInv)
@jax_funcify.register(Erfcx)
@jax_funcify.register(Erfcinv)
def jax_funcify_from_tfp(op, **kwargs):
@@ -733,6 +733,64 @@ class GammaIncC(BinaryScalarOp):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
class GammaIncInv(BinaryScalarOp):
"""
Inverse of the regularized lower incomplete gamma function.
"""
nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammaincinv(k, x)
def impl(self, k, x):
return GammaIncInv.st_impl(k, x)
def grad(self, inputs, grads):
(k, x) = inputs
(gz,) = grads
return [
grad_not_implemented(self, 0, k),
gz * exp(gammaincinv(k, x)) * gamma(k) * (gammaincinv(k, x) ** (1 - k)),
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
gammaincinv = GammaIncInv(upgrade_to_float, name="gammaincinv")
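
The `x`-gradient above is just the inverse function rule. Writing $y = \operatorname{gammaincinv}(k, x)$, i.e. $x = P(k, y) = \gamma(k, y) / \Gamma(k)$, differentiating the regularized lower incomplete gamma function in $y$ gives

\frac{\partial x}{\partial y} = \frac{y^{k-1} e^{-y}}{\Gamma(k)}
\quad\Longrightarrow\quad
\frac{\partial y}{\partial x} = \Gamma(k)\, e^{y}\, y^{1-k},

which is exactly the `exp(gammaincinv(k, x)) * gamma(k) * gammaincinv(k, x) ** (1 - k)` factor in `grad`. No similarly simple closed form exists for the derivative with respect to `k`, hence `grad_not_implemented`.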
class GammaIncCInv(BinaryScalarOp):
"""
Inverse of the regularized upper incomplete gamma function.
"""
nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammainccinv(k, x)
def impl(self, k, x):
return GammaIncCInv.st_impl(k, x)
def grad(self, inputs, grads):
(k, x) = inputs
(gz,) = grads
return [
grad_not_implemented(self, 0, k),
gz * -exp(gammainccinv(k, x)) * gamma(k) * (gammainccinv(k, x) ** (1 - k)),
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
gammainccinv = GammaIncCInv(upgrade_to_float, name="gammainccinv")
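
The upper-tail inverse only flips the sign, since $Q(k, y) = 1 - P(k, y)$. A quick SciPy-only sanity check of both closed forms against a central finite difference (a sketch for illustration, separate from the PR's test suite):

# Sketch: verify the two x-gradients above numerically with SciPy alone.
import numpy as np
from scipy.special import gamma, gammaincinv, gammainccinv

k, x, eps = 5.5, 0.25, 1e-6

y = gammaincinv(k, x)
analytic = gamma(k) * np.exp(y) * y ** (1 - k)
numeric = (gammaincinv(k, x + eps) - gammaincinv(k, x - eps)) / (2 * eps)
assert np.isclose(analytic, numeric, rtol=1e-5)

y = gammainccinv(k, x)
analytic = -gamma(k) * np.exp(y) * y ** (1 - k)  # sign flipped for the upper tail
numeric = (gammainccinv(k, x + eps) - gammainccinv(k, x - eps)) / (2 * eps)
assert np.isclose(analytic, numeric, rtol=1e-5)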
def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
init = [as_scalar(x) if x is not None else None for x in init]
constant = [as_scalar(x) for x in constant]
@@ -1648,6 +1706,43 @@ def betainc_grad(p, q, x, wrtp: bool):
return grad
class BetaIncInv(ScalarOp):
"""
Inverse of the regularized incomplete beta function.
"""
nfunc_spec = ("scipy.special.betaincinv", 3, 1)
def impl(self, a, b, x):
return scipy.special.betaincinv(a, b, x)
def grad(self, inputs, grads):
(a, b, x) = inputs
(gz,) = grads
return [
grad_not_implemented(self, 0, a),
grad_not_implemented(self, 0, b),
gz
* exp(betaln(a, b))
* ((1 - betaincinv(a, b, x)) ** (1 - b))
* (betaincinv(a, b, x) ** (1 - a)),
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
betaincinv = BetaIncInv(upgrade_to_float_no_complex, name="betaincinv")
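
As with the gamma inverses, the `x`-gradient follows from the inverse function rule: writing $y = \operatorname{betaincinv}(a, b, x)$, i.e. $x = I_y(a, b)$,

\frac{\partial x}{\partial y} = \frac{y^{a-1} (1-y)^{b-1}}{B(a, b)}
\quad\Longrightarrow\quad
\frac{\partial y}{\partial x} = B(a, b)\, y^{1-a}\, (1-y)^{1-b},

and $B(a, b)$ is evaluated as `exp(betaln(a, b))` via the helper below, so the beta factor is formed in log space, where it stays finite for large `a` and `b`.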
def betaln(a, b):
"""
Beta function from gamma function.
"""
return gammaln(a) + gammaln(b) - gammaln(a + b)
class Hyp2F1(ScalarOp):
"""
Gaussian hypergeometric function ``2F1(a, b; c; z)``.
@@ -283,6 +283,16 @@ def gammal_inplace(k, x):
"""lower incomplete gamma function"""
@scalar_elemwise
def gammaincinv_inplace(k, x):
"""Inverse to the regularized lower incomplete gamma function"""
@scalar_elemwise
def gammainccinv_inplace(k, x):
"""Inverse of the regularized upper incomplete gamma function"""
@scalar_elemwise
def j0_inplace(x):
"""Bessel function of the first kind of order 0."""
@@ -338,6 +348,11 @@ def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""
@scalar_elemwise
def betaincinv_inplace(a, b, x):
"""Inverse of the regularized incomplete beta function"""
@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
@@ -1385,6 +1385,16 @@ def gammal(k, x):
"""Lower incomplete gamma function."""
@scalar_elemwise
def gammaincinv(k, x):
"""Inverse to the regularized lower incomplete gamma function"""
@scalar_elemwise
def gammainccinv(k, x):
"""Inverse of the regularized upper incomplete gamma function"""
@scalar_elemwise
def hyp2f1(a, b, c, z):
"""Gaussian hypergeometric function."""
@@ -1451,6 +1461,11 @@ def betainc(a, b, x):
"""Regularized incomplete beta function"""
@scalar_elemwise
def betaincinv(a, b, x):
"""Inverse of the regularized incomplete beta function"""
@scalar_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`."""
@@ -3044,6 +3059,8 @@ __all__ = [
"gammaincc",
"gammau",
"gammal",
"gammaincinv",
"gammainccinv",
"j0",
"j1",
"jv",
@@ -3057,6 +3074,7 @@ __all__ = [
"log1pexp",
"log1mexp",
"betainc",
"betaincinv",
"real",
"imag",
"angle",
@@ -6,7 +6,7 @@ import scipy
from pytensor.graph.basic import Apply
from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import gamma, neg, sum
from pytensor.tensor.math import gamma, gammaln, neg, sum
class SoftmaxGrad(COp):
@@ -752,9 +752,27 @@ def factorial(n):
return gamma(n + 1)
def beta(a, b):
"""
Beta function.
"""
return (gamma(a) * gamma(b)) / gamma(a + b)
def betaln(a, b):
"""
Log beta function.
"""
return gammaln(a) + gammaln(b) - gammaln(a + b)
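
This log-space form is what makes `betaln` worth exporting alongside `beta`: going through `gamma` directly leaves float64 range long before `gammaln` does. An illustrative NumPy/SciPy analogue (`betaln_np` is a hypothetical stand-in, not part of this PR):

# Illustration only: beta via gamma over/underflows float64 for large
# arguments, while the log-space form remains finite.
import numpy as np
from scipy.special import gammaln

def betaln_np(a, b):  # hypothetical stand-in for pytensor's betaln
    return gammaln(a) + gammaln(b) - gammaln(a + b)

print(np.exp(betaln_np(5.0, 3.0)))  # 0.00952..., agrees with scipy.special.beta(5, 3)
print(betaln_np(500.0, 600.0))      # ~ -759.8: finite, while beta(500, 600) underflows to 0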
__all__ = [
"softmax",
"log_softmax",
"poch",
"factorial",
"beta",
"betaln",
]
@@ -11,12 +11,15 @@ from pytensor.tensor import as_tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import (
betaincinv,
cosh,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
gammainccinv,
gammaincinv,
iv,
log,
log1mexp,
@@ -165,6 +168,38 @@ def test_tfp_ops(op, test_values):
compare_jax_and_py(fg, test_values)
def test_betaincinv():
a = vector("a", dtype="float64")
b = vector("b", dtype="float64")
x = vector("x", dtype="float64")
out = betaincinv(a, b, x)
fg = FunctionGraph([a, b, x], [out])
compare_jax_and_py(
fg,
[
np.array([5.5, 7.0]),
np.array([5.5, 7.0]),
np.array([0.25, 0.7]),
],
)
def test_gammaincinv():
k = vector("k", dtype="float64")
x = vector("x", dtype="float64")
out = gammaincinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_gammainccinv():
k = vector("k", dtype="float64")
x = vector("x", dtype="float64")
out = gammainccinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_psi():
x = scalar("x")
out = psi(x)
@@ -69,6 +69,8 @@ expected_gammainc = scipy.special.gammainc
expected_gammaincc = scipy.special.gammaincc
expected_gammau = scipy_special_gammau
expected_gammal = scipy_special_gammal
expected_gammaincinv = scipy.special.gammaincinv
expected_gammainccinv = scipy.special.gammainccinv
expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1
expected_jv = scipy.special.jv
@@ -79,6 +81,7 @@ expected_ive = scipy.special.ive
expected_erfcx = scipy.special.erfcx
expected_sigmoid = scipy.special.expit
expected_hyp2f1 = scipy.special.hyp2f1
expected_betaincinv = scipy.special.betaincinv
TestErfBroadcast = makeBroadcastTester(
op=pt.erf,
@@ -484,6 +487,49 @@ TestGammaLInplaceBroadcast = makeBroadcastTester(
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_binary_gamma = dict(
normal=(
random_ranged(0, 100, (2, 3), rng=rng),
random_ranged(0, 1, (2, 3), rng=rng),
),
empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)),
)
TestGammaIncInvBroadcast = makeBroadcastTester(
op=pt.gammaincinv,
expected=expected_gammaincinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
)
TestGammaIncInvInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaincinv_inplace,
expected=expected_gammaincinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestGammaInccInvBroadcast = makeBroadcastTester(
op=pt.gammainccinv,
expected=expected_gammainccinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
)
TestGammaInccInvInplaceBroadcast = makeBroadcastTester(
op=inplace.gammainccinv_inplace,
expected=expected_gammainccinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_bessel = dict(
normal=(random_ranged(-10, 10, (2, 3), rng=rng),),
@@ -880,6 +926,27 @@ class TestBetaIncGrad:
)
_good_broadcast_ternary_betaincinv = dict(
normal=(
random_ranged(0, 1000, (2, 3)),
random_ranged(0, 1000, (2, 3)),
random_ranged(0, 1, (2, 3)),
),
)
TestBetaincinvBroadcast = makeBroadcastTester(
op=pt.betaincinv,
expected=scipy.special.betaincinv,
good=_good_broadcast_ternary_betaincinv,
)
TestBetaincinvInplaceBroadcast = makeBroadcastTester(
op=inplace.betaincinv_inplace,
expected=scipy.special.betaincinv,
good=_good_broadcast_ternary_betaincinv,
inplace=True,
)
_good_broadcast_quaternary_hyp2f1 = dict(
normal=(
random_ranged(0, 20, (2, 3)),
import numpy as np
import pytest
from scipy.special import beta as scipy_beta
from scipy.special import factorial as scipy_factorial
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import poch as scipy_poch
@@ -11,6 +12,8 @@ from pytensor.tensor.special import (
LogSoftmax,
Softmax,
SoftmaxGrad,
beta,
betaln,
factorial,
log_softmax,
poch,
@@ -171,3 +174,29 @@ def test_factorial(n):
np.testing.assert_allclose(
actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
)
def test_beta():
_a, _b = vectors("a", "b")
actual_fn = function([_a, _b], beta(_a, _b))
a = random_ranged(0, 5, (2,))
b = random_ranged(0, 5, (2,))
actual = actual_fn(a, b)
expected = scipy_beta(a, b)
np.testing.assert_allclose(
actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
)
def test_betaln():
_a, _b = vectors("a", "b")
actual_fn = function([_a, _b], betaln(_a, _b))
a = random_ranged(0, 5, (2,))
b = random_ranged(0, 5, (2,))
actual = np.exp(actual_fn(a, b))
expected = scipy_beta(a, b)
np.testing.assert_allclose(
actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
)