Commit 3041831e authored by Ricardo Vieira, committed by Ricardo Vieira

Use ScalarLoop for gammainc(c) gradients

Parent cd93444e
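This commit replaces the NumPy-based `impl` loops of the `GammaIncDer`/`GammaIncCDer` scalar Ops with symbolic `ScalarLoop` graphs, so the k-gradients of `gammainc`/`gammaincc` compile, fuse, and vectorize like any other scalar op instead of falling back to Python. As a minimal sketch of the pattern the new `_make_scalar_loop` helper wraps (constructor and call signatures as used in this diff; the toy update rule and names are illustrative only, not part of the commit):

from pytensor.scalar import float64, int32
from pytensor.scalar.loop import ScalarLoop

n_steps = int32("n_steps")
x0 = float64("x0")  # initial loop state
c = float64("c")    # value held constant across iterations

# Dummy variables of the same types define the symbolic loop body
x_ = x0.type()
c_ = c.type()
update = x_ * 2.0 + c_  # next state, computed from the previous one
until = update > 100.0  # early-exit condition

op = ScalarLoop(
    init=[x_],
    constant=[c_],
    update=[update],
    until=until,
    until_condition_failed="warn",
    name="toy_loop",
)
state, *rest = op(n_steps, x0, c)  # as in `_make_scalar_loop`: the first output is the carried state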
pytensor/scalar/math.py
@@ -18,8 +18,11 @@ from pytensor.scalar.basic import (
     BinaryScalarOp,
     ScalarOp,
     UnaryScalarOp,
+    as_scalar,
     complex_types,
+    constant,
     discrete_types,
+    eq,
     exp,
     expm1,
     float64,
@@ -27,6 +30,7 @@ from pytensor.scalar.basic import (
     isinf,
     log,
     log1p,
+    sqrt,
     switch,
     true_div,
     upcast,
@@ -34,6 +38,7 @@ from pytensor.scalar.basic import (
     upgrade_to_float64,
     upgrade_to_float_no_complex,
 )
+from pytensor.scalar.loop import ScalarLoop
 
 
 class Erf(UnaryScalarOp):
@@ -595,7 +600,7 @@ class GammaInc(BinaryScalarOp):
         (k, x) = inputs
         (gz,) = grads
         return [
-            gz * gammainc_der(k, x),
+            gz * gammainc_grad(k, x),
             gz * exp(-x + (k - 1) * log(x) - gammaln(k)),
         ]
@@ -644,7 +649,7 @@ class GammaIncC(BinaryScalarOp):
         (k, x) = inputs
         (gz,) = grads
         return [
-            gz * gammaincc_der(k, x),
+            gz * gammaincc_grad(k, x),
             gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
         ]
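For reference, the second gradient term in both `L_op`s above is unchanged by this commit: it is the closed-form x-partial, evaluated in log-space for stability. Only the k-partials need the series approximations that follow:

\frac{\partial P(k, x)}{\partial x} = \frac{x^{k-1} e^{-x}}{\Gamma(k)} = \exp\bigl(-x + (k - 1)\log x - \log\Gamma(k)\bigr),
\qquad
\frac{\partial Q(k, x)}{\partial x} = -\frac{\partial P(k, x)}{\partial x}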
@@ -675,162 +680,209 @@ class GammaIncC(BinaryScalarOp):
 gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
 
 
-class GammaIncDer(BinaryScalarOp):
-    """
-    Gradient of the the regularized lower gamma function (P) wrt to the first
-    argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
+def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
+    init = [as_scalar(x) for x in init]
+    constant = [as_scalar(x) for x in constant]
+    # Create dummy types, in case some variables have the same initial form
+    init_ = [x.type() for x in init]
+    constant_ = [x.type() for x in constant]
+    update_, until_ = inner_loop_fn(*init_, *constant_)
+    op = ScalarLoop(
+        init=init_,
+        constant=constant_,
+        update=update_,
+        until=until_,
+        until_condition_failed="warn",
+        name=name,
+    )
+    S, *_ = op(n_steps, *init, *constant)
+    return S
+
+
+def gammainc_grad(k, x):
+    """Gradient of the regularized lower gamma function (P) wrt the first
+    argument (k, a.k.a. alpha).
+
+    Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
 
     Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
     ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
     """
+    dtype = upcast(k.type.dtype, x.type.dtype, "float32")
 
-    def impl(self, k, x):
-        if x == 0:
-            return 0
-
-        sqrt_exp = -756 - x**2 + 60 * x
-        if (
-            (k < 0.8 and x > 15)
-            or (k < 12 and x > 30)
-            or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))
-        ):
-            return -GammaIncCDer.st_impl(k, x)
-
-        precision = 1e-10
-        max_iters = int(1e5)
+    def grad_approx(skip_loop):
+        precision = np.array(1e-10, dtype=config.floatX)
+        max_iters = switch(
+            skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
+        )
 
-        log_x = np.log(x)
-        log_gamma_k_plus_1 = scipy.special.gammaln(k + 1)
+        log_x = log(x)
+        log_gamma_k_plus_1 = gammaln(k + 1)
 
-        k_plus_n = k
+        # First loop
+        k_plus_n = k  # Should not overflow unless k > 2,147,483,647
         log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
-        sum_a = 0.0
-        for n in range(0, max_iters + 1):
-            term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
-            sum_a += term
-
-            if term <= precision:
-                break
-
-            log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
-            k_plus_n += 1
+        sum_a0 = np.array(0.0, dtype=dtype)
 
-        if n >= max_iters:
-            warnings.warn(
-                f"gammainc_der did not converge after {n} iterations",
-                RuntimeWarning,
-            )
-            return np.nan
+        def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):
+            term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
+            sum_a += term
+            log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
+            k_plus_n += 1
+            return (
+                (sum_a, log_gamma_k_plus_n_plus_1, k_plus_n),
+                (term <= precision),
+            )
 
-        k_plus_n = k
-        log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
-        sum_b = 0.0
-        for n in range(0, max_iters + 1):
-            term = np.exp(
-                k_plus_n * log_x - log_gamma_k_plus_n_plus_1
-            ) * scipy.special.digamma(k_plus_n + 1)
-            sum_b += term
-
-            if term <= precision and n >= 1:  # Require at least two iterations
-                return np.exp(-x) * (log_x * sum_a - sum_b)
-
-            log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
-            k_plus_n += 1
+        init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
+        constant = [log_x]
+        sum_a = _make_scalar_loop(
+            max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
+        )
 
-        warnings.warn(
-            f"gammainc_der did not converge after {n} iterations",
-            RuntimeWarning,
-        )
-        return np.nan
+        # Second loop
+        n = np.array(0, dtype="int32")
+        log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
+        k_plus_n = k
+        sum_b0 = np.array(0.0, dtype=dtype)
 
-    def c_code(self, *args, **kwargs):
-        raise NotImplementedError()
-
-
-gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
+        def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
+            term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
+            sum_b += term
+            log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
+            n += 1
+            k_plus_n += 1
+            return (
+                (sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n),
+                # Require at least two iterations
+                ((term <= precision) & (n > 1)),
+            )
 
+        init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
+        constant = [log_x]
+        sum_b = _make_scalar_loop(
+            max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
+        )
+
+        grad_approx = exp(-x) * (log_x * sum_a - sum_b)
+        return grad_approx
+
+    zero_branch = eq(x, 0)
+    sqrt_exp = -756 - x**2 + 60 * x
+    gammaincc_branch = (
+        ((k < 0.8) & (x > 15))
+        | ((k < 12) & (x > 30))
+        | ((sqrt_exp > 0) & (k < sqrt(sqrt_exp)))
+    )
+    grad = switch(
+        zero_branch,
+        0,
+        switch(
+            gammaincc_branch,
+            -gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)),
+            grad_approx(skip_loop=zero_branch | gammaincc_branch),
+        ),
+    )
+    return grad
 
-class GammaIncCDer(BinaryScalarOp):
-    """
-    Gradient of the the regularized upper gamma function (Q) wrt to the first
-    argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
-    """
+
+def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
+    """Gradient of the regularized upper gamma function (Q) wrt the first
+    argument (k, a.k.a. alpha).
+
+    Adapted from STAN `grad_reg_inc_gamma.hpp`
+
+    skip_loops is used for faster branching when this function is called by `gammainc_grad`
+    """
+    dtype = upcast(k.type.dtype, x.type.dtype, "float32")
 
-    @staticmethod
-    def st_impl(k, x):
-        gamma_k = scipy.special.gamma(k)
-        digamma_k = scipy.special.digamma(k)
-        log_x = np.log(x)
-
-        # asymptotic expansion http://dlmf.nist.gov/8.11#E2
-        if (x >= k) and (x >= 8):
-            S = 0
-            k_minus_one_minus_n = k - 1
-            fac = k_minus_one_minus_n
-            dfac = 1
-            xpow = x
-            delta = dfac / xpow
-
-            for n in range(1, 10):
-                k_minus_one_minus_n -= 1
-                S += delta
-                xpow *= x
-                dfac = k_minus_one_minus_n * dfac + fac
-                fac *= k_minus_one_minus_n
-                delta = dfac / xpow
-                if np.isinf(delta):
-                    warnings.warn(
-                        "gammaincc_der did not converge",
-                        RuntimeWarning,
-                    )
-                    return np.nan
-
-            return (
-                scipy.special.gammaincc(k, x) * (log_x - digamma_k)
-                + np.exp(-x + (k - 1) * log_x) * S / gamma_k
-            )
+    gamma_k = gamma(k)
+    digamma_k = psi(k)
+    log_x = log(x)
+
+    # asymptotic expansion http://dlmf.nist.gov/8.11#E2
+    def approx_a(skip_loop):
+        n_steps = switch(
+            skip_loop, np.array(0, dtype="int32"), np.array(9, dtype="int32")
+        )
+        sum_a0 = np.array(0.0, dtype=dtype)
+        dfac = np.array(1.0, dtype=dtype)
+        xpow = x
+        k_minus_one_minus_n = k - 1
+        fac = k_minus_one_minus_n
+        delta = true_div(dfac, xpow)
+
+        def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):
+            sum_a += delta
+            xpow *= x
+            k_minus_one_minus_n -= 1
+            dfac = k_minus_one_minus_n * dfac + fac
+            fac *= k_minus_one_minus_n
+            delta = dfac / xpow
+            return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), ()
+
+        init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
+        constant = [x]
+        sum_a = _make_scalar_loop(
+            n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
+        )
+
+        grad_approx_a = (
+            gammaincc(k, x) * (log_x - digamma_k)
+            + exp(-x + (k - 1) * log_x) * sum_a / gamma_k
+        )
+        return grad_approx_a
 
-        # gradient of series expansion http://dlmf.nist.gov/8.7#E3
-        else:
-            log_precision = np.log(1e-6)
-            max_iters = int(1e5)
-            S = 0
-            log_s = 0.0
-            s_sign = 1
-            log_delta = log_s - 2 * np.log(k)
-            for n in range(1, max_iters + 1):
-                S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
-                s_sign = -s_sign
-                log_s += log_x - np.log(n)
-                log_delta = log_s - 2 * np.log(n + k)
-
-                if np.isinf(log_delta):
-                    warnings.warn(
-                        "gammaincc_der did not converge",
-                        RuntimeWarning,
-                    )
-                    return np.nan
-
-                if log_delta <= log_precision:
-                    return (
-                        scipy.special.gammainc(k, x) * (digamma_k - log_x)
-                        + np.exp(k * log_x) * S / gamma_k
-                    )
-
-            warnings.warn(
-                f"gammaincc_der did not converge after {n} iterations",
-                RuntimeWarning,
-            )
-            return np.nan
-
-    def impl(self, k, x):
-        return self.st_impl(k, x)
-
-    def c_code(self, *args, **kwargs):
-        raise NotImplementedError()
-
-
-gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
+    # gradient of series expansion http://dlmf.nist.gov/8.7#E3
+    def approx_b(skip_loop):
+        max_iters = switch(
+            skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
+        )
+        log_precision = np.array(np.log(1e-6), dtype=config.floatX)
+
+        sum_b0 = np.array(0.0, dtype=dtype)
+        log_s = np.array(0.0, dtype=dtype)
+        s_sign = np.array(1, dtype="int8")
+        n = np.array(1, dtype="int32")
+        log_delta = log_s - 2 * log(k)
+
+        def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
+            delta = exp(log_delta)
+            sum_b += switch(s_sign > 0, delta, -delta)
+            s_sign = -s_sign
+
+            # log will cast >int16 to float64
+            log_s_inc = log_x - log(n)
+            if log_s_inc.type.dtype != log_s.type.dtype:
+                log_s_inc = log_s_inc.astype(log_s.type.dtype)
+            log_s += log_s_inc
+
+            new_log_delta = log_s - 2 * log(n + k)
+            if new_log_delta.type.dtype != log_delta.type.dtype:
+                new_log_delta = new_log_delta.astype(log_delta.type.dtype)
+            log_delta = new_log_delta
+
+            n += 1
+            return (
+                (sum_b, log_s, s_sign, log_delta, n),
+                log_delta <= log_precision,
+            )
+
+        init = [sum_b0, log_s, s_sign, log_delta, n]
+        constant = [k, log_x]
+        sum_b = _make_scalar_loop(
+            max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
+        )
+
+        grad_approx_b = (
+            gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
+        )
+        return grad_approx_b
+
+    branch_a = (x >= k) & (x >= 8)
+    return switch(
+        branch_a,
+        approx_a(skip_loop=~branch_a | skip_loops),
+        approx_b(skip_loop=branch_a | skip_loops),
+    )
 
 
 class GammaU(BinaryScalarOp):
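From the user's side the API is unchanged; graph-level gradients now simply lower to the ScalarLoop-based expressions above. A minimal usage sketch, mirroring the new benchmark test below (the values are chosen, as there, to hit the series-expansion branch):

import numpy as np
from pytensor import function, grad
from pytensor.tensor import gammaincc, vector

k = vector("k")
x = vector("x")
out = gammaincc(k, x)

# d/dk Q(k, x): the ScalarLoop inside the gradient graph compiles like any scalar op
grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
print(grad_fn(np.full((5,), 3.2), np.full((5,), 0.01)))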
tests/tensor/test_math_scipy.py
@@ -3,6 +3,8 @@ from contextlib import ExitStack as does_not_warn
 import numpy as np
 import pytest
 
+from pytensor.gradient import verify_grad
+
 scipy = pytest.importorskip("scipy")
@@ -11,11 +13,11 @@ from functools import partial
 import scipy.special
 import scipy.stats
 
-from pytensor import function
+from pytensor import function, grad
 from pytensor import tensor as at
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.tensor import inplace
+from pytensor.tensor import gammaincc, inplace, vector
 from tests import unittest_tools as utt
 from tests.tensor.utils import (
     _good_broadcast_unary_chi2sf,
@@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values():
     gammaincc_ddk = at.grad(gammainc_out, k)
     f_grad = function([k, x], gammaincc_ddk)
 
+    rtol = 1e-5 if config.floatX == "float64" else 1e-2
+    atol = 1e-10 if config.floatX == "float64" else 1e-6
+
     for test_k, test_x, expected_ddk in (
         (0.0001, 0, 0),  # Limit condition
         (0.0001, 0.0001, -8.62594024578651),
@@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values():
         (19.0001, 29.7501, -0.007828749832965796),
     ):
         np.testing.assert_allclose(
-            f_grad(test_k, test_x), expected_ddk, rtol=1e-5, atol=1e-14
+            f_grad(test_k, test_x), expected_ddk, rtol=rtol, atol=atol
         )
 
 
+def test_gammaincc_ddk_performance(benchmark):
+    rng = np.random.default_rng(1)
+
+    k = vector("k")
+    x = vector("x")
+
+    out = gammaincc(k, x)
+    grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
+    vals = [
+        # Values that hit the second branch of the gradient
+        np.full((1000,), 3.2),
+        np.full((1000,), 0.01),
+    ]
+
+    verify_grad(gammaincc, vals, rng=rng)
+    benchmark(grad_fn, *vals)
+
+
 TestGammaUBroadcast = makeBroadcastTester(
     op=at.gammau,
     expected=expected_gammau,
@@ -796,7 +818,7 @@ class TestBetaIncGrad:
         betainc_out = at.betainc(a, b, z)
         betainc_grad = at.grad(betainc_out, [a, b])
         f_grad = function([a, b, z], betainc_grad)
-
+        decimal = 7 if config.floatX == "float64" else 5
         for test_a, test_b, test_z, expected_dda, expected_ddb in (
             (1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
             (1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
@@ -806,6 +828,7 @@ class TestBetaIncGrad:
             np.testing.assert_almost_equal(
                 f_grad(test_a, test_b, test_z),
                 [expected_dda, expected_ddb],
+                decimal=decimal,
             )
 
     def test_beta_inc_stan_grad_combined(self):
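As an independent sanity check (not part of the commit), the tabulated d/dk values above can be reproduced with a central finite difference on SciPy's `gammainc`; the step size here is an arbitrary choice:

from scipy.special import gammainc

def ddk_fd(k, x, eps=1e-6):
    # Central-difference approximation of d/dk P(k, x)
    return (gammainc(k + eps, x) - gammainc(k - eps, x)) / (2 * eps)

print(ddk_fd(0.0001, 0.0001))  # ~ -8.625940, vs. tabulated -8.62594024578651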