提交 3041831e 作者: Ricardo Vieira 提交者: Ricardo Vieira

Use ScalarLoop for gammainc(c) gradients

上级 cd93444e
...@@ -18,8 +18,11 @@ from pytensor.scalar.basic import ( ...@@ -18,8 +18,11 @@ from pytensor.scalar.basic import (
BinaryScalarOp, BinaryScalarOp,
ScalarOp, ScalarOp,
UnaryScalarOp, UnaryScalarOp,
as_scalar,
complex_types, complex_types,
constant,
discrete_types, discrete_types,
eq,
exp, exp,
expm1, expm1,
float64, float64,
...@@ -27,6 +30,7 @@ from pytensor.scalar.basic import ( ...@@ -27,6 +30,7 @@ from pytensor.scalar.basic import (
isinf, isinf,
log, log,
log1p, log1p,
sqrt,
switch, switch,
true_div, true_div,
upcast, upcast,
...@@ -34,6 +38,7 @@ from pytensor.scalar.basic import ( ...@@ -34,6 +38,7 @@ from pytensor.scalar.basic import (
upgrade_to_float64, upgrade_to_float64,
upgrade_to_float_no_complex, upgrade_to_float_no_complex,
) )
from pytensor.scalar.loop import ScalarLoop
class Erf(UnaryScalarOp): class Erf(UnaryScalarOp):
...@@ -595,7 +600,7 @@ class GammaInc(BinaryScalarOp): ...@@ -595,7 +600,7 @@ class GammaInc(BinaryScalarOp):
(k, x) = inputs (k, x) = inputs
(gz,) = grads (gz,) = grads
return [ return [
gz * gammainc_der(k, x), gz * gammainc_grad(k, x),
gz * exp(-x + (k - 1) * log(x) - gammaln(k)), gz * exp(-x + (k - 1) * log(x) - gammaln(k)),
] ]
...@@ -644,7 +649,7 @@ class GammaIncC(BinaryScalarOp): ...@@ -644,7 +649,7 @@ class GammaIncC(BinaryScalarOp):
(k, x) = inputs (k, x) = inputs
(gz,) = grads (gz,) = grads
return [ return [
gz * gammaincc_der(k, x), gz * gammaincc_grad(k, x),
gz * -exp(-x + (k - 1) * log(x) - gammaln(k)), gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
] ]
...@@ -675,162 +680,209 @@ class GammaIncC(BinaryScalarOp): ...@@ -675,162 +680,209 @@ class GammaIncC(BinaryScalarOp):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
class GammaIncDer(BinaryScalarOp): def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
""" init = [as_scalar(x) for x in init]
Gradient of the the regularized lower gamma function (P) wrt to the first constant = [as_scalar(x) for x in constant]
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp` # Create dummy types, in case some variables have the same initial form
init_ = [x.type() for x in init]
constant_ = [x.type() for x in constant]
update_, until_ = inner_loop_fn(*init_, *constant_)
op = ScalarLoop(
init=init_,
constant=constant_,
update=update_,
until=until_,
until_condition_failed="warn",
name=name,
)
S, *_ = op(n_steps, *init, *constant)
return S
def gammainc_grad(k, x):
"""Gradient of the regularized lower gamma function (P) wrt to the first
argument (k, a.k.a. alpha).
Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions. Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481. ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
""" """
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
def impl(self, k, x): def grad_approx(skip_loop):
if x == 0: precision = np.array(1e-10, dtype=config.floatX)
return 0 max_iters = switch(
skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
sqrt_exp = -756 - x**2 + 60 * x )
if (
(k < 0.8 and x > 15)
or (k < 12 and x > 30)
or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))
):
return -GammaIncCDer.st_impl(k, x)
precision = 1e-10
max_iters = int(1e5)
log_x = np.log(x) log_x = log(x)
log_gamma_k_plus_1 = scipy.special.gammaln(k + 1) log_gamma_k_plus_1 = gammaln(k + 1)
k_plus_n = k # First loop
k_plus_n = k # Should not overflow unless k > 2,147,483,647
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
sum_a = 0.0 sum_a0 = np.array(0.0, dtype=dtype)
for n in range(0, max_iters + 1):
term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
sum_a += term
if term <= precision: def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):
break term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
sum_a += term
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
k_plus_n += 1 k_plus_n += 1
return (
if n >= max_iters: (sum_a, log_gamma_k_plus_n_plus_1, k_plus_n),
warnings.warn( (term <= precision),
f"gammainc_der did not converge after {n} iterations",
RuntimeWarning,
) )
return np.nan
k_plus_n = k init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
constant = [log_x]
sum_a = _make_scalar_loop(
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
)
# Second loop
n = np.array(0, dtype="int32")
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
sum_b = 0.0 k_plus_n = k
for n in range(0, max_iters + 1): sum_b0 = np.array(0.0, dtype=dtype)
term = np.exp(
k_plus_n * log_x - log_gamma_k_plus_n_plus_1
) * scipy.special.digamma(k_plus_n + 1)
sum_b += term
if term <= precision and n >= 1: # Require at least two iterations def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
return np.exp(-x) * (log_x * sum_a - sum_b) term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
sum_b += term
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
n += 1
k_plus_n += 1 k_plus_n += 1
return (
(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n),
# Require at least two iterations
((term <= precision) & (n > 1)),
)
warnings.warn( init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
f"gammainc_der did not converge after {n} iterations", constant = [log_x]
RuntimeWarning, sum_b, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
) )
return np.nan
def c_code(self, *args, **kwargs):
raise NotImplementedError()
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
grad_approx = exp(-x) * (log_x * sum_a - sum_b)
class GammaIncCDer(BinaryScalarOp): return grad_approx
"""
Gradient of the the regularized upper gamma function (Q) wrt to the first zero_branch = eq(x, 0)
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp` sqrt_exp = -756 - x**2 + 60 * x
gammaincc_branch = (
((k < 0.8) & (x > 15))
| ((k < 12) & (x > 30))
| ((sqrt_exp > 0) & (k < sqrt(sqrt_exp)))
)
grad = switch(
zero_branch,
0,
switch(
gammaincc_branch,
-gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)),
grad_approx(skip_loop=zero_branch | gammaincc_branch),
),
)
return grad
def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
"""Gradient of the regularized upper gamma function (Q) wrt to the first
argument (k, a.k.a. alpha).
Adapted from STAN `grad_reg_inc_gamma.hpp`
skip_loops is used for faster branching when this function is called by `gammainc_grad`
""" """
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
@staticmethod gamma_k = gamma(k)
def st_impl(k, x): digamma_k = psi(k)
gamma_k = scipy.special.gamma(k) log_x = log(x)
digamma_k = scipy.special.digamma(k)
log_x = np.log(x) def approx_a(skip_loop):
n_steps = switch(
# asymptotic expansion http://dlmf.nist.gov/8.11#E2 skip_loop, np.array(0, dtype="int32"), np.array(9, dtype="int32")
if (x >= k) and (x >= 8): )
S = 0 sum_a0 = np.array(0.0, dtype=dtype)
k_minus_one_minus_n = k - 1 dfac = np.array(1.0, dtype=dtype)
fac = k_minus_one_minus_n xpow = x
dfac = 1 k_minus_one_minus_n = k - 1
xpow = x fac = k_minus_one_minus_n
delta = true_div(dfac, xpow)
def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):
sum_a += delta
xpow *= x
k_minus_one_minus_n -= 1
dfac = k_minus_one_minus_n * dfac + fac
fac *= k_minus_one_minus_n
delta = dfac / xpow delta = dfac / xpow
return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), ()
for n in range(1, 10): init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
k_minus_one_minus_n -= 1 constant = [x]
S += delta sum_a = _make_scalar_loop(
xpow *= x n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
dfac = k_minus_one_minus_n * dfac + fac )
fac *= k_minus_one_minus_n grad_approx_a = (
delta = dfac / xpow gammaincc(k, x) * (log_x - digamma_k)
if np.isinf(delta): + exp(-x + (k - 1) * log_x) * sum_a / gamma_k
warnings.warn( )
"gammaincc_der did not converge", return grad_approx_a
RuntimeWarning,
)
return np.nan
def approx_b(skip_loop):
max_iters = switch(
skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
)
log_precision = np.array(np.log(1e-6), dtype=config.floatX)
sum_b0 = np.array(0.0, dtype=dtype)
log_s = np.array(0.0, dtype=dtype)
s_sign = np.array(1, dtype="int8")
n = np.array(1, dtype="int32")
log_delta = log_s - 2 * log(k)
def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
delta = exp(log_delta)
sum_b += switch(s_sign > 0, delta, -delta)
s_sign = -s_sign
# log will cast >int16 to float64
log_s_inc = log_x - log(n)
if log_s_inc.type.dtype != log_s.type.dtype:
log_s_inc = log_s_inc.astype(log_s.type.dtype)
log_s += log_s_inc
new_log_delta = log_s - 2 * log(n + k)
if new_log_delta.type.dtype != log_delta.type.dtype:
new_log_delta = new_log_delta.astype(log_delta.type.dtype)
log_delta = new_log_delta
n += 1
return ( return (
scipy.special.gammaincc(k, x) * (log_x - digamma_k) (sum_b, log_s, s_sign, log_delta, n),
+ np.exp(-x + (k - 1) * log_x) * S / gamma_k log_delta <= log_precision,
)
# gradient of series expansion http://dlmf.nist.gov/8.7#E3
else:
log_precision = np.log(1e-6)
max_iters = int(1e5)
S = 0
log_s = 0.0
s_sign = 1
log_delta = log_s - 2 * np.log(k)
for n in range(1, max_iters + 1):
S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
s_sign = -s_sign
log_s += log_x - np.log(n)
log_delta = log_s - 2 * np.log(n + k)
if np.isinf(log_delta):
warnings.warn(
"gammaincc_der did not converge",
RuntimeWarning,
)
return np.nan
if log_delta <= log_precision:
return (
scipy.special.gammainc(k, x) * (digamma_k - log_x)
+ np.exp(k * log_x) * S / gamma_k
)
warnings.warn(
f"gammaincc_der did not converge after {n} iterations",
RuntimeWarning,
) )
return np.nan
def impl(self, k, x):
return self.st_impl(k, x)
def c_code(self, *args, **kwargs):
raise NotImplementedError()
init = [sum_b0, log_s, s_sign, log_delta, n]
gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der") constant = [k, log_x]
sum_b = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
)
grad_approx_b = (
gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
)
return grad_approx_b
branch_a = (x >= k) & (x >= 8)
return switch(
branch_a,
approx_a(skip_loop=~branch_a | skip_loops),
approx_b(skip_loop=branch_a | skip_loops),
)
class GammaU(BinaryScalarOp): class GammaU(BinaryScalarOp):
......
...@@ -3,6 +3,8 @@ from contextlib import ExitStack as does_not_warn ...@@ -3,6 +3,8 @@ from contextlib import ExitStack as does_not_warn
import numpy as np import numpy as np
import pytest import pytest
from pytensor.gradient import verify_grad
scipy = pytest.importorskip("scipy") scipy = pytest.importorskip("scipy")
...@@ -11,11 +13,11 @@ from functools import partial ...@@ -11,11 +13,11 @@ from functools import partial
import scipy.special import scipy.special
import scipy.stats import scipy.stats
from pytensor import function from pytensor import function, grad
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor import inplace from pytensor.tensor import gammaincc, inplace, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.utils import ( from tests.tensor.utils import (
_good_broadcast_unary_chi2sf, _good_broadcast_unary_chi2sf,
...@@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values(): ...@@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values():
gammaincc_ddk = at.grad(gammainc_out, k) gammaincc_ddk = at.grad(gammainc_out, k)
f_grad = function([k, x], gammaincc_ddk) f_grad = function([k, x], gammaincc_ddk)
rtol = 1e-5 if config.floatX == "float64" else 1e-2
atol = 1e-10 if config.floatX == "float64" else 1e-6
for test_k, test_x, expected_ddk in ( for test_k, test_x, expected_ddk in (
(0.0001, 0, 0), # Limit condition (0.0001, 0, 0), # Limit condition
(0.0001, 0.0001, -8.62594024578651), (0.0001, 0.0001, -8.62594024578651),
...@@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values(): ...@@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values():
(19.0001, 29.7501, -0.007828749832965796), (19.0001, 29.7501, -0.007828749832965796),
): ):
np.testing.assert_allclose( np.testing.assert_allclose(
f_grad(test_k, test_x), expected_ddk, rtol=1e-5, atol=1e-14 f_grad(test_k, test_x), expected_ddk, rtol=rtol, atol=atol
) )
def test_gammaincc_ddk_performance(benchmark):
rng = np.random.default_rng(1)
k = vector("k")
x = vector("x")
out = gammaincc(k, x)
grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
vals = [
# Values that hit the second branch of the gradient
np.full((1000,), 3.2),
np.full((1000,), 0.01),
]
verify_grad(gammaincc, vals, rng=rng)
benchmark(grad_fn, *vals)
TestGammaUBroadcast = makeBroadcastTester( TestGammaUBroadcast = makeBroadcastTester(
op=at.gammau, op=at.gammau,
expected=expected_gammau, expected=expected_gammau,
...@@ -796,7 +818,7 @@ class TestBetaIncGrad: ...@@ -796,7 +818,7 @@ class TestBetaIncGrad:
betainc_out = at.betainc(a, b, z) betainc_out = at.betainc(a, b, z)
betainc_grad = at.grad(betainc_out, [a, b]) betainc_grad = at.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad) f_grad = function([a, b, z], betainc_grad)
decimal = 7 if config.floatX == "float64" else 5
for test_a, test_b, test_z, expected_dda, expected_ddb in ( for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04), (1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
(1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04), (1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
...@@ -806,6 +828,7 @@ class TestBetaIncGrad: ...@@ -806,6 +828,7 @@ class TestBetaIncGrad:
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
f_grad(test_a, test_b, test_z), f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb], [expected_dda, expected_ddb],
decimal=decimal,
) )
def test_beta_inc_stan_grad_combined(self): def test_beta_inc_stan_grad_combined(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论