Commit 39d37df6 authored by Ricardo Vieira, committed by Ricardo Vieira

Use ScalarLoop for betainc gradient

Parent 3041831e
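
For context, this commit replaces the pure-Python `impl` of `BetaIncDer` with a symbolic graph built on `pytensor.scalar.loop.ScalarLoop`, so the betainc gradient can be compiled and transformed like any other scalar op. Below is a minimal sketch of the ScalarLoop API the new code relies on; the variable names and toy update rule are illustrative only, and the call convention (`n_steps` first, then init values, then constants) is an assumption based on pytensor's scalar loop tests, not part of this diff.

from pytensor.scalar import float64, ge, int32
from pytensor.scalar.loop import ScalarLoop

n_steps = int32("n_steps")  # maximum number of iterations
x0 = float64("x0")          # loop carry (initial value)
c = float64("c")            # loop constant, unchanged across iterations

x_next = x0 + c             # body of one iteration: x <- x + c

# `until` is optional: with it, the loop can exit early and the op returns an
# extra boolean output reporting whether the condition was ever reached.
loop_op = ScalarLoop(init=[x0], update=[x_next], constant=[c], until=ge(x_next, 10.0))

# Assumed call convention: n_steps first, then init values, then constants.
x_final, converged = loop_op(n_steps, x0, c)

The `_make_scalar_loop` helper used later in the diff wraps exactly this pattern around the `inner_loop` carry-and-condition function.
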
......@@ -14,10 +14,9 @@ import scipy.stats
from pytensor.configdefaults import config
from pytensor.gradient import grad_not_implemented
from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
from pytensor.scalar.basic import abs as scalar_abs
from pytensor.scalar.basic import (
BinaryScalarOp,
ScalarOp,
UnaryScalarOp,
as_scalar,
complex_types,
constant,
......@@ -27,9 +26,12 @@ from pytensor.scalar.basic import (
expm1,
float64,
float_types,
identity,
isinf,
log,
log1p,
reciprocal,
scalar_maximum,
sqrt,
switch,
true_div,
......@@ -1329,8 +1331,8 @@ class BetaInc(ScalarOp):
(gz,) = grads
return [
gz * betainc_der(a, b, x, True),
gz * betainc_der(a, b, x, False),
gz * betainc_grad(a, b, x, True),
gz * betainc_grad(a, b, x, False),
gz
* exp(
log1p(-x) * (b - 1)
......@@ -1346,28 +1348,28 @@ class BetaInc(ScalarOp):
betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")
class BetaIncDer(ScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the first
argument (alpha) or the second argument (beta), depending on whether the
fourth argument to betainc_der is `True` or `False`, respectively.
def betainc_grad(p, q, x, wrtp: bool):
"""Gradient of the regularized lower gamma function (P) wrt to the first
argument (k, a.k.a. alpha).
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
"""
nin = 4
def _betainc_der(p, q, x, wrtp, skip_loop):
dtype = upcast(p.type.dtype, q.type.dtype, x.type.dtype, "float32")
def betaln(a, b):
return gammaln(a) + (gammaln(b) - gammaln(a + b))
def impl(self, p, q, x, wrtp):
def _betainc_a_n(f, p, q, n):
"""
Numerator (a_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""
if n == 1:
return p * f * (q - 1) / (q * (p + 1))
p2n = p + 2 * n
F1 = p**2 * f**2 * (n - 1) / (q**2)
F2 = (
......@@ -1377,7 +1379,11 @@ class BetaIncDer(ScalarOp):
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
)
return F1 * F2
return switch(
eq(n, 1),
p * f * (q - 1) / (q * (p + 1)),
F1 * F2,
)
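# Note how the Python-level `if n == 1:` early return of the old impl becomes
# a symbolic switch(eq(n, 1), ...) above: both branch subgraphs always exist,
# and `switch` merely selects one value at run time, which is what allows the
# whole iteration body to live inside a ScalarLoop.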
def _betainc_b_n(f, p, q, n):
"""
......@@ -1397,9 +1403,6 @@ class BetaIncDer(ScalarOp):
Derivative of a_n wrt p
"""
if n == 1:
return -p * f * (q - 1) / (q * (p + 1) ** 2)
pp = p**2
ppp = pp * p
p2n = p + 2 * n
......@@ -1414,20 +1417,25 @@ class BetaIncDer(ScalarOp):
D1 = q**2 * (p2n - 3) ** 2
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2
return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2
return switch(
eq(n, 1),
-p * f * (q - 1) / (q * (p + 1) ** 2),
(N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2,
)
def _betainc_da_n_dq(f, p, q, n):
"""
Derivative of a_n wrt q
"""
if n == 1:
return p * f / (q * (p + 1))
p2n = p + 2 * n
F1 = (p**2 * f**2 / (q**2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)
return F1 / D1
return switch(
eq(n, 1),
p * f / (q * (p + 1)),
F1 / D1,
)
def _betainc_db_n_dp(f, p, q, n):
"""
......@@ -1452,42 +1460,44 @@ class BetaIncDer(ScalarOp):
p2n = p + 2 * n
return -(p**2 * f) / (q * (p2n - 2) * p2n)
# Input validation
if not (0 <= x <= 1) or p < 0 or q < 0:
return np.nan
if x > (p / (p + q)):
return -self.impl(q, p, 1 - x, not wrtp)
min_iters = 3
max_iters = 200
err_threshold = 1e-12
derivative_old = 0
min_iters = np.array(3, dtype="int32")
max_iters = switch(
skip_loop, np.array(0, dtype="int32"), np.array(200, dtype="int32")
)
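# `skip_loop` is true when this call sits in the dead branch of the outer
# switch in betainc_grad below; clamping max_iters to 0 keeps that branch's
# graph valid while avoiding wasted iterations (and spurious NaNs) on inputs
# it was never meant to handle.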
err_threshold = np.array(1e-12, dtype=config.floatX)
Am2, Am1 = 1, 1
Bm2, Bm1 = 0, 1
dAm2, dAm1 = 0, 0
dBm2, dBm1 = 0, 0
Am2, Am1 = np.array(1, dtype=dtype), np.array(1, dtype=dtype)
Bm2, Bm1 = np.array(0, dtype=dtype), np.array(1, dtype=dtype)
dAm2, dAm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)
dBm2, dBm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)
f = (q * x) / (p * (1 - x))
K = np.exp(
p * np.log(x)
+ (q - 1) * np.log1p(-x)
- np.log(p)
- scipy.special.betaln(p, q)
)
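# Prefactor of the continued-fraction representation (Boik & Robison-Cox, 1998):
# K = x^p * (1 - x)^(q - 1) / (p * B(p, q)), evaluated in log space for stability.
# dK below is d(log K)/dp when wrtp, else d(log K)/dq, using digamma (psi).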
K = exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - betaln(p, q))
if wrtp:
dK = (
np.log(x)
- 1 / p
+ scipy.special.digamma(p + q)
- scipy.special.digamma(p)
)
dK = log(x) - reciprocal(p) + psi(p + q) - psi(p)
else:
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)
for n in range(1, max_iters + 1):
dK = log1p(-x) + psi(p + q) - psi(q)
derivative = np.array(0, dtype=dtype)
n = np.array(1, dtype="int16") # Enough for 200 max iters
def inner_loop(
derivative,
Am2,
Am1,
Bm2,
Bm1,
dAm2,
dAm1,
dBm2,
dBm1,
n,
f,
p,
q,
K,
dK,
):
a_n_ = _betainc_a_n(f, p, q, n)
b_n_ = _betainc_b_n(f, p, q, n)
if wrtp:
......@@ -1502,36 +1512,53 @@ class BetaIncDer(ScalarOp):
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
Am2, Am1 = Am1, A
Bm2, Bm1 = Bm1, B
dAm2, dAm1 = dAm1, dA
dBm2, dBm1 = dBm1, dB
if n < min_iters - 1:
continue
Am2, Am1 = identity(Am1), identity(A)
Bm2, Bm1 = identity(Bm1), identity(B)
dAm2, dAm1 = identity(dAm1), identity(dA)
dBm2, dBm1 = identity(dBm1), identity(dB)
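# A/B is the nth convergent of the continued fraction and F2 is its
# derivative by the quotient rule, so `derivative_new` below approximates
# d(betainc)/dp (or dq) at this iteration.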
F1 = A / B
F2 = (dA - F1 * dB) / B
derivative = K * (F1 * dK + F2)
derivative_new = K * (F1 * dK + F2)
errapx = abs(derivative_old - derivative)
d_errapx = errapx / max(err_threshold, abs(derivative))
derivative_old = derivative
if d_errapx <= err_threshold:
return derivative
errapx = scalar_abs(derivative - derivative_new)
d_errapx = errapx / scalar_maximum(
err_threshold, scalar_abs(derivative_new)
)
warnings.warn(
f"betainc_der did not converge after {n} iterations",
RuntimeWarning,
min_iters_cond = n > (min_iters - 1)
derivative = switch(
min_iters_cond,
derivative_new,
derivative,
)
return np.nan
n += 1
def c_code(self, *args, **kwargs):
raise NotImplementedError()
return (
(derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n),
(d_errapx <= err_threshold) & min_iters_cond,
)
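# The loop body returns its updated carry plus an `until` condition: the
# ScalarLoop assembled by _make_scalar_loop exits early once the relative
# change drops below err_threshold, provided at least min_iters steps ran.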
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
constant = [f, p, q, K, dK]
grad = _make_scalar_loop(
max_iters, init, constant, inner_loop, name="betainc_grad"
)
return grad
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
# Input validation
nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0)
flip_branch = x > (p / (p + q))
grad = switch(
nan_branch,
np.nan,
switch(
flip_branch,
-_betainc_der(q, p, 1 - x, not wrtp, skip_loop=nan_branch | (~flip_branch)),
_betainc_der(p, q, x, wrtp, skip_loop=nan_branch | flip_branch),
),
)
return grad
class Hyp2F1(ScalarOp):
......
......@@ -8,7 +8,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import CLinker
from pytensor.scalar.math import (
betainc,
betainc_der,
betainc_grad,
gammainc,
gammaincc,
gammal,
......@@ -82,7 +82,7 @@ def test_betainc():
def test_betainc_derivative_nan():
a, b, x = at.scalars("a", "b", "x")
res = betainc_der(a, b, x, True)
res = betainc_grad(a, b, x, True)
test_func = function([a, b, x], res, mode=Mode("py"))
assert not np.isnan(test_func(1, 1, 1))
assert np.isnan(test_func(1, 1, -1))
......