提交 2b78c67f authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Implement betainc and derivatives

上级 b5313f1e
......@@ -5,6 +5,7 @@ As SciPy is not always available, we treat them separately.
"""
import os
import warnings
import numpy as np
import scipy.special
......@@ -14,12 +15,15 @@ from aesara.configdefaults import config
from aesara.gradient import grad_not_implemented
from aesara.scalar.basic import (
BinaryScalarOp,
ScalarOp,
UnaryScalarOp,
complex_types,
discrete_types,
exp,
float64,
float_types,
log,
log1p,
true_div,
upcast,
upgrade_to_float,
......@@ -1044,3 +1048,221 @@ class Log1mexp(UnaryScalarOp):
log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp")
class BetaInc(ScalarOp):
"""
Regularized incomplete beta function
"""
nin = 3
nfunc_spec = ("scipy.special.betainc", 3, 1)
def impl(self, a, b, x):
return scipy.special.betainc(a, b, x)
def grad(self, inp, grads):
a, b, x = inp
(gz,) = grads
return [
gz * betainc_der(a, b, x, True),
gz * betainc_der(a, b, x, False),
gz
* exp(
log1p(-x) * (b - 1)
+ log(x) * (a - 1)
- (gammaln(a) + gammaln(b) - gammaln(a + b))
),
]
betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")
class BetaIncDer(ScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the first
argument (alpha) or the second argument (bbeta), depending on whether the
fourth argument to betainc_der is `True` or `False`, respectively.
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
"""
nin = 4
def impl(self, p, q, x, wrtp):
def _betainc_a_n(f, p, q, n):
"""
Numerator (a_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""
if n == 1:
return p * f * (q - 1) / (q * (p + 1))
p2n = p + 2 * n
F1 = p ** 2 * f ** 2 * (n - 1) / (q ** 2)
F2 = (
(p + q + n - 2)
* (p + n - 1)
* (q - n)
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
)
return F1 * F2
def _betainc_b_n(f, p, q, n):
"""
Offset (b_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""
pf = p * f
p2n = p + 2 * n
N1 = 2 * (pf + 2 * q) * n * (n + p - 1) + p * q * (p - 2 - pf)
D1 = q * (p2n - 2) * p2n
return N1 / D1
def _betainc_da_n_dp(f, p, q, n):
"""
Derivative of a_n wrt p
"""
if n == 1:
return -p * f * (q - 1) / (q * (p + 1) ** 2)
pp = p ** 2
ppp = pp * p
p2n = p + 2 * n
N1 = -(n - 1) * f ** 2 * pp * (q - n)
N2a = (-8 + 8 * p + 8 * q) * n ** 3
N2b = (16 * pp + (-44 + 20 * q) * p + 26 - 24 * q) * n ** 2
N2c = (10 * ppp + (14 * q - 46) * pp + (-40 * q + 66) * p - 28 + 24 * q) * n
N2d = 2 * pp ** 2 + (-13 + 3 * q) * ppp + (-14 * q + 30) * pp
N2e = (-29 + 19 * q) * p + 10 - 8 * q
D1 = q ** 2 * (p2n - 3) ** 2
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2
return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2
def _betainc_da_n_dq(f, p, q, n):
"""
Derivative of a_n wrt q
"""
if n == 1:
return p * f / (q * (p + 1))
p2n = p + 2 * n
F1 = (p ** 2 * f ** 2 / (q ** 2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)
return F1 / D1
def _betainc_db_n_dp(f, p, q, n):
"""
Derivative of b_n wrt p
"""
p2n = p + 2 * n
pp = p ** 2
q4 = 4 * q
p4 = 4 * p
F1 = (p * f / q) * (
(-p4 - q4 + 4) * n ** 2 + (p4 - 4 + q4 - 2 * pp) * n + pp * q
)
D1 = (p2n - 2) ** 2 * p2n ** 2
return F1 / D1
def _betainc_db_n_dq(f, p, q, n):
"""
Derivative of b_n wrt to q
"""
p2n = p + 2 * n
return -(p ** 2 * f) / (q * (p2n - 2) * p2n)
# Input validation
if not (0 <= x <= 1) or p < 0 or q < 0:
return np.nan
if x > (p / (p + q)):
return -self.impl(q, p, 1 - x, not wrtp)
min_iters = 3
max_iters = 200
err_threshold = 1e-12
derivative_old = 0
Am2, Am1 = 1, 1
Bm2, Bm1 = 0, 1
dAm2, dAm1 = 0, 0
dBm2, dBm1 = 0, 0
f = (q * x) / (p * (1 - x))
K = np.exp(
p * np.log(x)
+ (q - 1) * np.log1p(-x)
- np.log(p)
- scipy.special.betaln(p, q)
)
if wrtp:
dK = (
np.log(x)
- 1 / p
+ scipy.special.digamma(p + q)
- scipy.special.digamma(p)
)
else:
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)
for n in range(1, max_iters + 1):
a_n_ = _betainc_a_n(f, p, q, n)
b_n_ = _betainc_b_n(f, p, q, n)
if wrtp:
da_n = _betainc_da_n_dp(f, p, q, n)
db_n = _betainc_db_n_dp(f, p, q, n)
else:
da_n = _betainc_da_n_dq(f, p, q, n)
db_n = _betainc_db_n_dq(f, p, q, n)
A = a_n_ * Am2 + b_n_ * Am1
B = a_n_ * Bm2 + b_n_ * Bm1
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
Am2, Am1 = Am1, A
Bm2, Bm1 = Bm1, B
dAm2, dAm1 = dAm1, dA
dBm2, dBm1 = dBm1, dB
if n < min_iters - 1:
continue
F1 = A / B
F2 = (dA - F1 * dB) / B
derivative = K * (F1 * dK + F2)
errapx = abs(derivative_old - derivative)
d_errapx = errapx / max(err_threshold, abs(derivative))
derivative_old = derivative
if d_errapx <= err_threshold:
break
if n >= max_iters:
warnings.warn(
f"_betainc_derivative did not converge after {n} iterations",
RuntimeWarning,
)
return np.nan
return derivative
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
......@@ -323,6 +323,11 @@ def log1mexp_inplace(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""
@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
......
......@@ -1429,6 +1429,11 @@ def log1mexp(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
def betainc(a, b, x):
"""Regularized incomplete beta function"""
@scalar_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
......@@ -2909,6 +2914,7 @@ __all__ = [
"softplus",
"log1pexp",
"log1mexp",
"betainc",
"real",
"imag",
"angle",
......
import numpy as np
import scipy.special as sp
import aesara.tensor as aet
from aesara import function
from aesara.compile.mode import Mode
from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker
from aesara.scalar.math import gammainc, gammaincc, gammal, gammau
from aesara.scalar.math import betainc, betainc_der, gammainc, gammaincc, gammal, gammau
def test_gammainc_nan():
......@@ -44,3 +47,21 @@ def test_gammau_nan():
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))
def test_betainc():
a, b, x = aet.scalars("a", "b", "x")
res = betainc(a, b, x)
test_func = function([a, b, x], res, mode=Mode("py"))
assert np.isclose(test_func(15, 10, 0.7), sp.betainc(15, 10, 0.7))
def test_betainc_derivative_nan():
a, b, x = aet.scalars("a", "b", "x")
res = betainc_der(a, b, x, True)
test_func = function([a, b, x], res, mode=Mode("py"))
assert not np.isnan(test_func(1, 1, 1))
assert np.isnan(test_func(1, 1, -1))
assert np.isnan(test_func(1, 1, 2))
assert np.isnan(test_func(1, -1, 1))
assert np.isnan(test_func(1, 1, -1))
......@@ -9,6 +9,7 @@ from functools import partial
import scipy.special
import scipy.stats
from aesara import function
from aesara import tensor as aet
from aesara.compile.mode import get_default_mode
from aesara.configdefaults import config
......@@ -603,3 +604,101 @@ TestLog1mexpInplaceBroadcast = makeBroadcastTester(
def test_deprecated_module():
with pytest.warns(DeprecationWarning):
import aesara.scalar.basic_scipy # noqa: F401
_good_broadcast_ternary_betainc = dict(
normal=(
random_ranged(0, 1000, (2, 3)),
random_ranged(0, 1000, (2, 3)),
random_ranged(0, 1, (2, 3)),
),
)
TestBetaincBroadcast = makeBroadcastTester(
op=aet.betainc,
expected=scipy.special.betainc,
good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc,
)
TestBetaincInplaceBroadcast = makeBroadcastTester(
op=inplace.betainc_inplace,
expected=scipy.special.betainc,
good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc,
inplace=True,
)
class TestBetaIncGrad:
def test_stan_grad_partial(self):
# This test combines the following STAN tests:
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_dda_test.cpp
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddb_test.cpp
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddz_test.cpp
a, b, z = aet.scalars("a", "b", "z")
betainc_out = aet.betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b, z])
f_grad = function([a, b, z], betainc_grad)
decimal_precision = 7 if config.floatX == "float64" else 3
for test_a, test_b, test_z, expected_dda, expected_ddb, expected_ddz in (
(1.5, 1.25, 0.001, -0.00028665637, 4.41357328e-05, 0.063300692),
(1.5, 1.25, 0.5, -0.26038693947, 0.29301795, 1.1905416),
(1.5, 1.25, 0.6, -0.23806757, 0.32279575, 1.23341068),
(1.5, 1.25, 0.999, -0.00022264493, 0.0018969609, 0.35587692),
(15000, 1.25, 0.001, 0, 0, 0),
(15000, 1.25, 0.5, 0, 0, 0),
(15000, 1.25, 0.6, 0, 0, 0),
(15000, 1.25, 0.999, -6.59543226e-10, 2.00849793e-06, 0.009898182),
(1.5, 12500, 0.001, -3.93756641e-05, 1.47821755e-09, 0.1848717),
(1.5, 12500, 0.5, 0, 0, 0),
(1.5, 12500, 0.6, 0, 0, 0),
(1.5, 12500, 0.999, 0, 0, 0),
(15000, 12500, 0.001, 0, 0, 0),
(15000, 12500, 0.5, -8.72102443e-53, 9.55282792e-53, 5.01131256e-48),
(15000, 12500, 0.6, -4.085621e-14, -5.5067062e-14, 1.15135267e-71),
(15000, 12500, 0.999, 0, 0, 0),
):
np.testing.assert_almost_equal(
f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb, expected_ddz],
decimal=decimal_precision,
)
def test_boik_robison_cox(self):
# This test compares against the tabulated values in:
# Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
# Journal of Statistical Software, 3(1), 1-20.
a, b, z = aet.scalars("a", "b", "z")
betainc_out = aet.betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad)
for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
(1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
(1000.0, 1000.0, 0.5, -8.9224793e-03, 8.9224793e-03),
(1000.0, 1000.0, 0.55, -3.6713108e-07, 4.0584118e-07),
):
np.testing.assert_almost_equal(
f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb],
)
def test_beta_inc_stan_grad_combined(self):
# This test replicates the following STAN test:
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_reg_inc_beta_test.cpp
a, b, z = aet.scalars("a", "b", "z")
betainc_out = aet.betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad)
for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.0, 1.0, 1.0, 0, np.nan),
(1.0, 1.0, 0.4, -0.36651629, 0.30649537),
):
np.testing.assert_allclose(
f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论