提交 c42b56ab，作者：Ricardo，提交者：Thomas Wiecki

Add derivatives of gammainc(c)

上级 f1444292
...@@ -556,6 +556,14 @@ class GammaInc(BinaryScalarOp): ...@@ -556,6 +556,14 @@ class GammaInc(BinaryScalarOp):
def impl(self, k, x): def impl(self, k, x):
return GammaInc.st_impl(k, x) return GammaInc.st_impl(k, x)
def grad(self, inputs, grads):
    """
    Build the gradient of ``gammainc(k, x)`` with respect to both inputs.

    Parameters
    ----------
    inputs
        The pair ``(k, x)`` passed to the op.
    grads
        One-element sequence holding the gradient flowing in from the output.

    Returns
    -------
    list
        ``[d/dk, d/dx]`` contributions, each scaled by the output gradient.
    """
    k, x = inputs
    (output_grad,) = grads

    # Derivative wrt k has no closed form here; it is delegated to the
    # dedicated `gammainc_der` op.
    ddk = gammainc_der(k, x)
    # Derivative wrt x: x**(k - 1) * exp(-x) / Gamma(k), assembled in
    # log-space and exponentiated once.
    ddx = exp(-x + (k - 1) * log(x) - gammaln(k))

    return [output_grad * ddk, output_grad * ddx]
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
raw = f.read() raw = f.read()
...@@ -597,6 +605,14 @@ class GammaIncC(BinaryScalarOp): ...@@ -597,6 +605,14 @@ class GammaIncC(BinaryScalarOp):
def impl(self, k, x): def impl(self, k, x):
return GammaIncC.st_impl(k, x) return GammaIncC.st_impl(k, x)
def grad(self, inputs, grads):
    """
    Build the gradient of ``gammaincc(k, x)`` with respect to both inputs.

    Parameters
    ----------
    inputs
        The pair ``(k, x)`` passed to the op.
    grads
        One-element sequence holding the gradient flowing in from the output.

    Returns
    -------
    list
        ``[d/dk, d/dx]`` contributions, each scaled by the output gradient.
    """
    k, x = inputs
    (output_grad,) = grads

    # Derivative wrt k is delegated to the dedicated `gammaincc_der` op.
    ddk = gammaincc_der(k, x)
    # Derivative wrt x: -x**(k - 1) * exp(-x) / Gamma(k), assembled in
    # log-space; the sign is applied after exponentiation.
    ddx = -exp(-x + (k - 1) * log(x) - gammaln(k))

    return [output_grad * ddk, output_grad * ddx]
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
raw = f.read() raw = f.read()
...@@ -624,6 +640,159 @@ class GammaIncC(BinaryScalarOp): ...@@ -624,6 +640,159 @@ class GammaIncC(BinaryScalarOp):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
class GammaIncDer(BinaryScalarOp):
    """
    Gradient of the regularized lower gamma function (P) with respect to the
    first argument (k, a.k.a. alpha).

    Adapted from STAN `grad_reg_lower_inc_gamma.hpp`.

    Reference: Gautschi, W. (1979). A computational procedure for incomplete
    gamma functions. ACM Transactions on Mathematical Software (TOMS), 5(4),
    466-481.
    """

    def impl(self, k, x):
        """
        Evaluate d/dk P(k, x) for scalar ``k`` and ``x``.

        Returns ``0`` when ``x == 0``. For the (k, x) regions selected by the
        thresholds below, the result is the negated upper-function derivative
        from ``GammaIncCDer.st_impl``. Otherwise two series are summed and
        combined as ``exp(-x) * (log(x) * sum_a - sum_b)``.

        Returns ``np.nan`` and emits a ``RuntimeWarning`` if either series
        fails to reach the precision target within the iteration budget.
        """
        if x == 0:
            return 0

        # Region test: for these (k, x) combinations defer to the upper
        # incomplete gamma derivative. The lower and upper regularized
        # functions are complementary, hence the sign flip.
        # NOTE(review): thresholds presumably mirror the STAN implementation.
        sqrt_exp = -756 - x ** 2 + 60 * x
        if (
            (k < 0.8 and x > 15)
            or (k < 12 and x > 30)
            or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))
        ):
            return -GammaIncCDer.st_impl(k, x)

        precision = 1e-10
        max_iters = int(1e5)

        log_x = np.log(x)
        log_gamma_k_plus_1 = scipy.special.gammaln(k + 1)

        # First series: sum_a = sum_n x**(k + n) / Gamma(k + n + 1),
        # with each term evaluated in log-space.
        k_plus_n = k
        log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
        sum_a = 0.0
        for n in range(0, max_iters + 1):
            term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
            sum_a += term

            if term <= precision:
                break

            # log Gamma(k + n + 2) = log Gamma(k + n + 1) + log(k + n + 1)
            log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
            k_plus_n += 1
        else:
            # Only reached when the loop ran out of iterations without
            # breaking. A post-loop `n >= max_iters` test would also fire
            # when convergence happens exactly on the last iteration.
            warnings.warn(
                f"gammainc_der did not converge after {n} iterations",
                RuntimeWarning,
            )
            return np.nan

        # Second series: same terms as above, each weighted by
        # digamma(k + n + 1).
        k_plus_n = k
        log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
        sum_b = 0.0
        for n in range(0, max_iters + 1):
            term = np.exp(
                k_plus_n * log_x - log_gamma_k_plus_n_plus_1
            ) * scipy.special.digamma(k_plus_n + 1)
            sum_b += term

            # Require at least two iterations. Presumably this guards against
            # a negative digamma value making the very first term pass the
            # `<= precision` test trivially.
            if term <= precision and n >= 1:
                return np.exp(-x) * (log_x * sum_a - sum_b)

            log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
            k_plus_n += 1

        warnings.warn(
            f"gammainc_der did not converge after {n} iterations",
            RuntimeWarning,
        )
        return np.nan
# Module-level GammaIncDer instance; GammaInc.grad calls it as
# `gammainc_der(k, x)` for the derivative with respect to k.
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
class GammaIncCDer(BinaryScalarOp):
    """
    Gradient of the regularized upper gamma function (Q) with respect to the
    first argument (k, a.k.a. alpha).

    Adapted from STAN `grad_reg_inc_gamma.hpp`.
    """

    @staticmethod
    def st_impl(k, x):
        """
        Evaluate d/dk Q(k, x) for scalar ``k`` and ``x``.

        Returns ``0`` when ``x == 0``. Uses a fixed-length asymptotic
        expansion when ``x >= k`` and ``x >= 8``, and the gradient of a
        series expansion otherwise.

        Returns ``np.nan`` and emits a ``RuntimeWarning`` when the expansion
        overflows or the series does not reach the precision target within
        the iteration budget.
        """
        # Q(k, 0) == 1 for every valid k, so the derivative wrt k is exactly
        # zero. Without this guard log(0) = -inf would flow into the series
        # branch and be misreported as a convergence failure (nan). This also
        # matches the x == 0 handling in GammaIncDer.impl.
        if x == 0:
            return 0

        gamma_k = scipy.special.gamma(k)
        digamma_k = scipy.special.digamma(k)
        log_x = np.log(x)

        # asymptotic expansion http://dlmf.nist.gov/8.11#E2
        if (x >= k) and (x >= 8):
            S = 0
            k_minus_one_minus_n = k - 1
            fac = k_minus_one_minus_n
            dfac = 1
            xpow = x
            delta = dfac / xpow

            # Fixed number of terms (9); `fac` and `dfac` are advanced
            # together by the recurrences below, and `delta = dfac / xpow`
            # is the term added to S on the next pass.
            for n in range(1, 10):
                k_minus_one_minus_n -= 1
                S += delta
                xpow *= x
                dfac = k_minus_one_minus_n * dfac + fac
                fac *= k_minus_one_minus_n
                delta = dfac / xpow
                if np.isinf(delta):
                    warnings.warn(
                        "gammaincc_der did not converge",
                        RuntimeWarning,
                    )
                    return np.nan

            return (
                scipy.special.gammaincc(k, x) * (log_x - digamma_k)
                + np.exp(-x + (k - 1) * log_x) * S / gamma_k
            )

        # gradient of series expansion http://dlmf.nist.gov/8.7#E3
        else:
            log_precision = np.log(1e-6)
            max_iters = int(1e5)
            S = 0
            # log_s tracks log(x**n / n!); the alternating sign is carried
            # separately in s_sign so everything stays in log-space.
            log_s = 0.0
            s_sign = 1
            log_delta = log_s - 2 * np.log(k)
            for n in range(1, max_iters + 1):
                S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
                s_sign = -s_sign
                log_s += log_x - np.log(n)
                log_delta = log_s - 2 * np.log(n + k)

                if np.isinf(log_delta):
                    warnings.warn(
                        "gammaincc_der did not converge",
                        RuntimeWarning,
                    )
                    return np.nan

                # Stop once the magnitude of the next term is below 1e-6.
                if log_delta <= log_precision:
                    return (
                        scipy.special.gammainc(k, x) * (digamma_k - log_x)
                        + np.exp(k * log_x) * S / gamma_k
                    )

            warnings.warn(
                f"gammaincc_der did not converge after {n} iterations",
                RuntimeWarning,
            )
            return np.nan

    def impl(self, k, x):
        """Scalar implementation; forwards to the static ``st_impl``."""
        return self.st_impl(k, x)
# Module-level GammaIncCDer instance; GammaIncC.grad calls it as
# `gammaincc_der(k, x)` for the derivative with respect to k.
gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
class GammaU(BinaryScalarOp): class GammaU(BinaryScalarOp):
""" """
compute the upper incomplete gamma function. compute the upper incomplete gamma function.
......
...@@ -272,10 +272,19 @@ _good_broadcast_binary_gamma = dict( ...@@ -272,10 +272,19 @@ _good_broadcast_binary_gamma = dict(
), ),
) )
# Inputs for the gradient checks of gammainc / gammaincc: the regular
# "normal" case is reused, plus hand-picked values.
# NOTE(review): the two arrays are presumably paired element-wise as (k, x),
# i.e. (0.7, 16.0), (11.0, 31.0) and (19.0, 3.0), chosen to hit different
# branches of the derivative implementations -- confirm against the tester.
_good_broadcast_binary_gamma_grad = {
    "normal": _good_broadcast_binary_gamma["normal"],
    "specific_branches": (
        np.array([0.7, 11.0, 19.0]),
        np.array([16.0, 31.0, 3.0]),
    ),
}
TestGammaIncBroadcast = makeBroadcastTester( TestGammaIncBroadcast = makeBroadcastTester(
op=aet.gammainc, op=aet.gammainc,
expected=expected_gammainc, expected=expected_gammainc,
good=_good_broadcast_binary_gamma, good=_good_broadcast_binary_gamma,
grad=_good_broadcast_binary_gamma_grad,
eps=2e-8, eps=2e-8,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
...@@ -293,6 +302,7 @@ TestGammaInccBroadcast = makeBroadcastTester( ...@@ -293,6 +302,7 @@ TestGammaInccBroadcast = makeBroadcastTester(
op=aet.gammaincc, op=aet.gammaincc,
expected=expected_gammaincc, expected=expected_gammaincc,
good=_good_broadcast_binary_gamma, good=_good_broadcast_binary_gamma,
grad=_good_broadcast_binary_gamma_grad,
eps=2e-8, eps=2e-8,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
...@@ -306,6 +316,53 @@ TestGammaInccInplaceBroadcast = makeBroadcastTester( ...@@ -306,6 +316,53 @@ TestGammaInccInplaceBroadcast = makeBroadcastTester(
inplace=True, inplace=True,
) )
def test_gammainc_ddk_tabulated_values():
    """Compare d/dk gammainc(k, x) against tabulated reference values."""
    # The reference table replicates part of the old STAN test:
    # https://github.com/stan-dev/math/blob/21333bb70b669a1bd54d444ecbe1258078d33153/test/unit/math/prim/scal/fun/grad_reg_lower_inc_gamma_test.cpp
    # Each row is (k, x, expected derivative wrt k).
    tabulated = (
        (0.0001, 0, 0),  # Limit condition
        (0.0001, 0.0001, -8.62594024578651),
        (0.0001, 6.2501, -0.0002705821702813008),
        (0.0001, 12.5001, -2.775406821933887e-7),
        (0.0001, 18.7501, -3.653379783274905e-10),
        (0.0001, 25.0001, -5.352425240798134e-13),
        (0.0001, 29.7501, -3.912723010174313e-15),
        (4.7501, 0.0001, 0),
        (4.7501, 6.2501, -0.1330287013623819),
        (4.7501, 12.5001, -0.004712176128251421),
        (4.7501, 18.7501, -0.00004898939126595217),
        (4.7501, 25.0001, -3.098781566343336e-7),
        (4.7501, 29.7501, -5.478399030091586e-9),
        (9.5001, 0.0001, -5.869126325643798e-15),
        (9.5001, 6.2501, -0.07717967485372858),
        (9.5001, 12.5001, -0.07661095137424883),
        (9.5001, 18.7501, -0.005594043337407605),
        (9.5001, 25.0001, -0.0001410123206233104),
        (9.5001, 29.7501, -5.75023943432906e-6),
        (14.2501, 0.0001, -7.24495484418588e-15),
        (14.2501, 6.2501, -0.003689474744087815),
        (14.2501, 12.5001, -0.1008796179460247),
        (14.2501, 18.7501, -0.05124664255610913),
        (14.2501, 25.0001, -0.005115177188580634),
        (14.2501, 29.7501, -0.0004793406401524598),
        (19.0001, 0.0001, -8.26027539153394e-15),
        (19.0001, 6.2501, -0.00003509660448733015),
        (19.0001, 12.5001, -0.02624562607393565),
        (19.0001, 18.7501, -0.0923829735092193),
        (19.0001, 25.0001, -0.03641281853907181),
        (19.0001, 29.7501, -0.007828749832965796),
    )

    # Build and compile the symbolic derivative once, then evaluate per row.
    k, x = aet.scalars("k", "x")
    gammainc_ddk = aet.grad(aet.gammainc(k, x), k)
    eval_ddk = function([k, x], gammainc_ddk)

    for k_value, x_value, reference in tabulated:
        np.testing.assert_allclose(
            eval_ddk(k_value, x_value), reference, rtol=1e-5, atol=1e-14
        )
TestGammaUBroadcast = makeBroadcastTester( TestGammaUBroadcast = makeBroadcastTester(
op=aet.gammau, op=aet.gammau,
expected=expected_gammau, expected=expected_gammau,
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论