提交 39d37df6 作者: Ricardo Vieira 提交者: Ricardo Vieira

Use ScalarLoop for betainc gradient

上级 3041831e
...@@ -14,10 +14,9 @@ import scipy.stats ...@@ -14,10 +14,9 @@ import scipy.stats
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad_not_implemented from pytensor.gradient import grad_not_implemented
from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
from pytensor.scalar.basic import abs as scalar_abs
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
BinaryScalarOp,
ScalarOp,
UnaryScalarOp,
as_scalar, as_scalar,
complex_types, complex_types,
constant, constant,
...@@ -27,9 +26,12 @@ from pytensor.scalar.basic import ( ...@@ -27,9 +26,12 @@ from pytensor.scalar.basic import (
expm1, expm1,
float64, float64,
float_types, float_types,
identity,
isinf, isinf,
log, log,
log1p, log1p,
reciprocal,
scalar_maximum,
sqrt, sqrt,
switch, switch,
true_div, true_div,
...@@ -1329,8 +1331,8 @@ class BetaInc(ScalarOp): ...@@ -1329,8 +1331,8 @@ class BetaInc(ScalarOp):
(gz,) = grads (gz,) = grads
return [ return [
gz * betainc_der(a, b, x, True), gz * betainc_grad(a, b, x, True),
gz * betainc_der(a, b, x, False), gz * betainc_grad(a, b, x, False),
gz gz
* exp( * exp(
log1p(-x) * (b - 1) log1p(-x) * (b - 1)
...@@ -1346,28 +1348,28 @@ class BetaInc(ScalarOp): ...@@ -1346,28 +1348,28 @@ class BetaInc(ScalarOp):
betainc = BetaInc(upgrade_to_float_no_complex, name="betainc") betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")
class BetaIncDer(ScalarOp): def betainc_grad(p, q, x, wrtp: bool):
""" """Gradient of the regularized incomplete beta function wrt the first
Gradient of the regularized incomplete beta function wrt to the first argument (p) or the second argument (q), depending on whether `wrtp` is true.
argument (alpha) or the second argument (beta), depending on whether the
fourth argument to betainc_der is `True` or `False`, respectively.
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
""" """
nin = 4 def _betainc_der(p, q, x, wrtp, skip_loop):
dtype = upcast(p.type.dtype, q.type.dtype, x.type.dtype, "float32")
def betaln(a, b):
return gammaln(a) + (gammaln(b) - gammaln(a + b))
def impl(self, p, q, x, wrtp):
def _betainc_a_n(f, p, q, n): def _betainc_a_n(f, p, q, n):
""" """
Numerator (a_n) of the nth approximant of the continued fraction Numerator (a_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function representation of the regularized incomplete beta function
""" """
if n == 1:
return p * f * (q - 1) / (q * (p + 1))
p2n = p + 2 * n p2n = p + 2 * n
F1 = p**2 * f**2 * (n - 1) / (q**2) F1 = p**2 * f**2 * (n - 1) / (q**2)
F2 = ( F2 = (
...@@ -1377,7 +1379,11 @@ class BetaIncDer(ScalarOp): ...@@ -1377,7 +1379,11 @@ class BetaIncDer(ScalarOp):
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)) / ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
) )
return F1 * F2 return switch(
eq(n, 1),
p * f * (q - 1) / (q * (p + 1)),
F1 * F2,
)
def _betainc_b_n(f, p, q, n): def _betainc_b_n(f, p, q, n):
""" """
...@@ -1397,9 +1403,6 @@ class BetaIncDer(ScalarOp): ...@@ -1397,9 +1403,6 @@ class BetaIncDer(ScalarOp):
Derivative of a_n wrt p Derivative of a_n wrt p
""" """
if n == 1:
return -p * f * (q - 1) / (q * (p + 1) ** 2)
pp = p**2 pp = p**2
ppp = pp * p ppp = pp * p
p2n = p + 2 * n p2n = p + 2 * n
...@@ -1414,20 +1417,25 @@ class BetaIncDer(ScalarOp): ...@@ -1414,20 +1417,25 @@ class BetaIncDer(ScalarOp):
D1 = q**2 * (p2n - 3) ** 2 D1 = q**2 * (p2n - 3) ** 2
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2 D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2
return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2 return switch(
eq(n, 1),
-p * f * (q - 1) / (q * (p + 1) ** 2),
(N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2,
)
def _betainc_da_n_dq(f, p, q, n): def _betainc_da_n_dq(f, p, q, n):
""" """
Derivative of a_n wrt q Derivative of a_n wrt q
""" """
if n == 1:
return p * f / (q * (p + 1))
p2n = p + 2 * n p2n = p + 2 * n
F1 = (p**2 * f**2 / (q**2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2) F1 = (p**2 * f**2 / (q**2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1) D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)
return F1 / D1 return switch(
eq(n, 1),
p * f / (q * (p + 1)),
F1 / D1,
)
def _betainc_db_n_dp(f, p, q, n): def _betainc_db_n_dp(f, p, q, n):
""" """
...@@ -1452,42 +1460,44 @@ class BetaIncDer(ScalarOp): ...@@ -1452,42 +1460,44 @@ class BetaIncDer(ScalarOp):
p2n = p + 2 * n p2n = p + 2 * n
return -(p**2 * f) / (q * (p2n - 2) * p2n) return -(p**2 * f) / (q * (p2n - 2) * p2n)
# Input validation min_iters = np.array(3, dtype="int32")
if not (0 <= x <= 1) or p < 0 or q < 0: max_iters = switch(
return np.nan skip_loop, np.array(0, dtype="int32"), np.array(200, dtype="int32")
)
if x > (p / (p + q)): err_threshold = np.array(1e-12, dtype=config.floatX)
return -self.impl(q, p, 1 - x, not wrtp)
min_iters = 3
max_iters = 200
err_threshold = 1e-12
derivative_old = 0
Am2, Am1 = 1, 1 Am2, Am1 = np.array(1, dtype=dtype), np.array(1, dtype=dtype)
Bm2, Bm1 = 0, 1 Bm2, Bm1 = np.array(0, dtype=dtype), np.array(1, dtype=dtype)
dAm2, dAm1 = 0, 0 dAm2, dAm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)
dBm2, dBm1 = 0, 0 dBm2, dBm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)
f = (q * x) / (p * (1 - x)) f = (q * x) / (p * (1 - x))
K = np.exp( K = exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - betaln(p, q))
p * np.log(x)
+ (q - 1) * np.log1p(-x)
- np.log(p)
- scipy.special.betaln(p, q)
)
if wrtp: if wrtp:
dK = ( dK = log(x) - reciprocal(p) + psi(p + q) - psi(p)
np.log(x)
- 1 / p
+ scipy.special.digamma(p + q)
- scipy.special.digamma(p)
)
else: else:
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q) dK = log1p(-x) + psi(p + q) - psi(q)
for n in range(1, max_iters + 1): derivative = np.array(0, dtype=dtype)
n = np.array(1, dtype="int16") # Enough for 200 max iters
def inner_loop(
derivative,
Am2,
Am1,
Bm2,
Bm1,
dAm2,
dAm1,
dBm2,
dBm1,
n,
f,
p,
q,
K,
dK,
):
a_n_ = _betainc_a_n(f, p, q, n) a_n_ = _betainc_a_n(f, p, q, n)
b_n_ = _betainc_b_n(f, p, q, n) b_n_ = _betainc_b_n(f, p, q, n)
if wrtp: if wrtp:
...@@ -1502,36 +1512,53 @@ class BetaIncDer(ScalarOp): ...@@ -1502,36 +1512,53 @@ class BetaIncDer(ScalarOp):
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1 dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1 dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
Am2, Am1 = Am1, A Am2, Am1 = identity(Am1), identity(A)
Bm2, Bm1 = Bm1, B Bm2, Bm1 = identity(Bm1), identity(B)
dAm2, dAm1 = dAm1, dA dAm2, dAm1 = identity(dAm1), identity(dA)
dBm2, dBm1 = dBm1, dB dBm2, dBm1 = identity(dBm1), identity(dB)
if n < min_iters - 1:
continue
F1 = A / B F1 = A / B
F2 = (dA - F1 * dB) / B F2 = (dA - F1 * dB) / B
derivative = K * (F1 * dK + F2) derivative_new = K * (F1 * dK + F2)
errapx = abs(derivative_old - derivative) errapx = scalar_abs(derivative - derivative_new)
d_errapx = errapx / max(err_threshold, abs(derivative)) d_errapx = errapx / scalar_maximum(
derivative_old = derivative err_threshold, scalar_abs(derivative_new)
)
if d_errapx <= err_threshold:
return derivative
warnings.warn( min_iters_cond = n > (min_iters - 1)
f"betainc_der did not converge after {n} iterations", derivative = switch(
RuntimeWarning, min_iters_cond,
derivative_new,
derivative,
) )
return np.nan n += 1
def c_code(self, *args, **kwargs): return (
raise NotImplementedError() (derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n),
(d_errapx <= err_threshold) & min_iters_cond,
)
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
constant = [f, p, q, K, dK]
grad = _make_scalar_loop(
max_iters, init, constant, inner_loop, name="betainc_grad"
)
return grad
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der") # Input validation
nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0)
flip_branch = x > (p / (p + q))
grad = switch(
nan_branch,
np.nan,
switch(
flip_branch,
-_betainc_der(q, p, 1 - x, not wrtp, skip_loop=nan_branch | (~flip_branch)),
_betainc_der(p, q, x, wrtp, skip_loop=nan_branch | flip_branch),
),
)
return grad
class Hyp2F1(ScalarOp): class Hyp2F1(ScalarOp):
......
...@@ -8,7 +8,7 @@ from pytensor.graph.fg import FunctionGraph ...@@ -8,7 +8,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import CLinker from pytensor.link.c.basic import CLinker
from pytensor.scalar.math import ( from pytensor.scalar.math import (
betainc, betainc,
betainc_der, betainc_grad,
gammainc, gammainc,
gammaincc, gammaincc,
gammal, gammal,
...@@ -82,7 +82,7 @@ def test_betainc(): ...@@ -82,7 +82,7 @@ def test_betainc():
def test_betainc_derivative_nan(): def test_betainc_derivative_nan():
a, b, x = at.scalars("a", "b", "x") a, b, x = at.scalars("a", "b", "x")
res = betainc_der(a, b, x, True) res = betainc_grad(a, b, x, True)
test_func = function([a, b, x], res, mode=Mode("py")) test_func = function([a, b, x], res, mode=Mode("py"))
assert not np.isnan(test_func(1, 1, 1)) assert not np.isnan(test_func(1, 1, 1))
assert np.isnan(test_func(1, 1, -1)) assert np.isnan(test_func(1, 1, -1))
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论