提交 3041831e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use ScalarLoop for gammainc(c) gradients

上级 cd93444e
...@@ -18,8 +18,11 @@ from pytensor.scalar.basic import ( ...@@ -18,8 +18,11 @@ from pytensor.scalar.basic import (
BinaryScalarOp, BinaryScalarOp,
ScalarOp, ScalarOp,
UnaryScalarOp, UnaryScalarOp,
as_scalar,
complex_types, complex_types,
constant,
discrete_types, discrete_types,
eq,
exp, exp,
expm1, expm1,
float64, float64,
...@@ -27,6 +30,7 @@ from pytensor.scalar.basic import ( ...@@ -27,6 +30,7 @@ from pytensor.scalar.basic import (
isinf, isinf,
log, log,
log1p, log1p,
sqrt,
switch, switch,
true_div, true_div,
upcast, upcast,
...@@ -34,6 +38,7 @@ from pytensor.scalar.basic import ( ...@@ -34,6 +38,7 @@ from pytensor.scalar.basic import (
upgrade_to_float64, upgrade_to_float64,
upgrade_to_float_no_complex, upgrade_to_float_no_complex,
) )
from pytensor.scalar.loop import ScalarLoop
class Erf(UnaryScalarOp): class Erf(UnaryScalarOp):
...@@ -595,7 +600,7 @@ class GammaInc(BinaryScalarOp): ...@@ -595,7 +600,7 @@ class GammaInc(BinaryScalarOp):
(k, x) = inputs (k, x) = inputs
(gz,) = grads (gz,) = grads
return [ return [
gz * gammainc_der(k, x), gz * gammainc_grad(k, x),
gz * exp(-x + (k - 1) * log(x) - gammaln(k)), gz * exp(-x + (k - 1) * log(x) - gammaln(k)),
] ]
...@@ -644,7 +649,7 @@ class GammaIncC(BinaryScalarOp): ...@@ -644,7 +649,7 @@ class GammaIncC(BinaryScalarOp):
(k, x) = inputs (k, x) = inputs
(gz,) = grads (gz,) = grads
return [ return [
gz * gammaincc_der(k, x), gz * gammaincc_grad(k, x),
gz * -exp(-x + (k - 1) * log(x) - gammaln(k)), gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
] ]
...@@ -675,162 +680,209 @@ class GammaIncC(BinaryScalarOp): ...@@ -675,162 +680,209 @@ class GammaIncC(BinaryScalarOp):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
class GammaIncDer(BinaryScalarOp): def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
""" init = [as_scalar(x) for x in init]
Gradient of the the regularized lower gamma function (P) wrt to the first constant = [as_scalar(x) for x in constant]
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp` # Create dummy types, in case some variables have the same initial form
init_ = [x.type() for x in init]
constant_ = [x.type() for x in constant]
update_, until_ = inner_loop_fn(*init_, *constant_)
op = ScalarLoop(
init=init_,
constant=constant_,
update=update_,
until=until_,
until_condition_failed="warn",
name=name,
)
S, *_ = op(n_steps, *init, *constant)
return S
def gammainc_grad(k, x):
"""Gradient of the regularized lower gamma function (P) wrt to the first
argument (k, a.k.a. alpha).
Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions. Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481. ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
""" """
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
def impl(self, k, x): def grad_approx(skip_loop):
if x == 0: precision = np.array(1e-10, dtype=config.floatX)
return 0 max_iters = switch(
skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
sqrt_exp = -756 - x**2 + 60 * x )
if (
(k < 0.8 and x > 15)
or (k < 12 and x > 30)
or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))
):
return -GammaIncCDer.st_impl(k, x)
precision = 1e-10
max_iters = int(1e5)
log_x = np.log(x) log_x = log(x)
log_gamma_k_plus_1 = scipy.special.gammaln(k + 1) log_gamma_k_plus_1 = gammaln(k + 1)
k_plus_n = k # First loop
k_plus_n = k # Should not overflow unless k > 2,147,383,647
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
sum_a = 0.0 sum_a0 = np.array(0.0, dtype=dtype)
for n in range(0, max_iters + 1):
term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
sum_a += term
if term <= precision: def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):
break term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
sum_a += term
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
k_plus_n += 1 k_plus_n += 1
return (
(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n),
(term <= precision),
)
if n >= max_iters: init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
warnings.warn( constant = [log_x]
f"gammainc_der did not converge after {n} iterations", sum_a = _make_scalar_loop(
RuntimeWarning, max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
) )
return np.nan
k_plus_n = k # Second loop
n = np.array(0, dtype="int32")
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
sum_b = 0.0 k_plus_n = k
for n in range(0, max_iters + 1): sum_b0 = np.array(0.0, dtype=dtype)
term = np.exp(
k_plus_n * log_x - log_gamma_k_plus_n_plus_1
) * scipy.special.digamma(k_plus_n + 1)
sum_b += term
if term <= precision and n >= 1: # Require at least two iterations def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
return np.exp(-x) * (log_x * sum_a - sum_b) term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
sum_b += term
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
n += 1
k_plus_n += 1 k_plus_n += 1
return (
(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n),
# Require at least two iterations
((term <= precision) & (n > 1)),
)
warnings.warn( init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
f"gammainc_der did not converge after {n} iterations", constant = [log_x]
RuntimeWarning, sum_b, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
) )
return np.nan
def c_code(self, *args, **kwargs): grad_approx = exp(-x) * (log_x * sum_a - sum_b)
raise NotImplementedError() return grad_approx
zero_branch = eq(x, 0)
sqrt_exp = -756 - x**2 + 60 * x
gammaincc_branch = (
((k < 0.8) & (x > 15))
| ((k < 12) & (x > 30))
| ((sqrt_exp > 0) & (k < sqrt(sqrt_exp)))
)
grad = switch(
zero_branch,
0,
switch(
gammaincc_branch,
-gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)),
grad_approx(skip_loop=zero_branch | gammaincc_branch),
),
)
return grad
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
"""Gradient of the regularized upper gamma function (Q) wrt to the first
argument (k, a.k.a. alpha).
class GammaIncCDer(BinaryScalarOp): Adapted from STAN `grad_reg_inc_gamma.hpp`
"""
Gradient of the the regularized upper gamma function (Q) wrt to the first skip_loops is used for faster branching when this function is called by `gammainc_der`
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
""" """
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
@staticmethod gamma_k = gamma(k)
def st_impl(k, x): digamma_k = psi(k)
gamma_k = scipy.special.gamma(k) log_x = log(x)
digamma_k = scipy.special.digamma(k)
log_x = np.log(x)
# asymptotic expansion http://dlmf.nist.gov/8.11#E2 def approx_a(skip_loop):
if (x >= k) and (x >= 8): n_steps = switch(
S = 0 skip_loop, np.array(0, dtype="int32"), np.array(9, dtype="int32")
)
sum_a0 = np.array(0.0, dtype=dtype)
dfac = np.array(1.0, dtype=dtype)
xpow = x
k_minus_one_minus_n = k - 1 k_minus_one_minus_n = k - 1
fac = k_minus_one_minus_n fac = k_minus_one_minus_n
dfac = 1 delta = true_div(dfac, xpow)
xpow = x
delta = dfac / xpow
for n in range(1, 10): def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):
k_minus_one_minus_n -= 1 sum_a += delta
S += delta
xpow *= x xpow *= x
k_minus_one_minus_n -= 1
dfac = k_minus_one_minus_n * dfac + fac dfac = k_minus_one_minus_n * dfac + fac
fac *= k_minus_one_minus_n fac *= k_minus_one_minus_n
delta = dfac / xpow delta = dfac / xpow
if np.isinf(delta): return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), ()
warnings.warn(
"gammaincc_der did not converge", init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
RuntimeWarning, constant = [x]
sum_a = _make_scalar_loop(
n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
) )
return np.nan grad_approx_a = (
gammaincc(k, x) * (log_x - digamma_k)
+ exp(-x + (k - 1) * log_x) * sum_a / gamma_k
)
return grad_approx_a
return ( def approx_b(skip_loop):
scipy.special.gammaincc(k, x) * (log_x - digamma_k) max_iters = switch(
+ np.exp(-x + (k - 1) * log_x) * S / gamma_k skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
) )
log_precision = np.array(np.log(1e-6), dtype=config.floatX)
# gradient of series expansion http://dlmf.nist.gov/8.7#E3 sum_b0 = np.array(0.0, dtype=dtype)
else: log_s = np.array(0.0, dtype=dtype)
log_precision = np.log(1e-6) s_sign = np.array(1, dtype="int8")
max_iters = int(1e5) n = np.array(1, dtype="int32")
S = 0 log_delta = log_s - 2 * log(k)
log_s = 0.0
s_sign = 1 def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
log_delta = log_s - 2 * np.log(k) delta = exp(log_delta)
for n in range(1, max_iters + 1): sum_b += switch(s_sign > 0, delta, -delta)
S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
s_sign = -s_sign s_sign = -s_sign
log_s += log_x - np.log(n)
log_delta = log_s - 2 * np.log(n + k)
if np.isinf(log_delta): # log will cast >int16 to float64
warnings.warn( log_s_inc = log_x - log(n)
"gammaincc_der did not converge", if log_s_inc.type.dtype != log_s.type.dtype:
RuntimeWarning, log_s_inc = log_s_inc.astype(log_s.type.dtype)
) log_s += log_s_inc
return np.nan
if log_delta <= log_precision: new_log_delta = log_s - 2 * log(n + k)
if new_log_delta.type.dtype != log_delta.type.dtype:
new_log_delta = new_log_delta.astype(log_delta.type.dtype)
log_delta = new_log_delta
n += 1
return ( return (
scipy.special.gammainc(k, x) * (digamma_k - log_x) (sum_b, log_s, s_sign, log_delta, n),
+ np.exp(k * log_x) * S / gamma_k log_delta <= log_precision,
) )
warnings.warn( init = [sum_b0, log_s, s_sign, log_delta, n]
f"gammaincc_der did not converge after {n} iterations", constant = [k, log_x]
RuntimeWarning, sum_b = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
) )
return np.nan grad_approx_b = (
gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
def impl(self, k, x): )
return self.st_impl(k, x) return grad_approx_b
def c_code(self, *args, **kwargs):
raise NotImplementedError()
gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der") branch_a = (x >= k) & (x >= 8)
return switch(
branch_a,
approx_a(skip_loop=~branch_a | skip_loops),
approx_b(skip_loop=branch_a | skip_loops),
)
class GammaU(BinaryScalarOp): class GammaU(BinaryScalarOp):
......
...@@ -3,6 +3,8 @@ from contextlib import ExitStack as does_not_warn ...@@ -3,6 +3,8 @@ from contextlib import ExitStack as does_not_warn
import numpy as np import numpy as np
import pytest import pytest
from pytensor.gradient import verify_grad
scipy = pytest.importorskip("scipy") scipy = pytest.importorskip("scipy")
...@@ -11,11 +13,11 @@ from functools import partial ...@@ -11,11 +13,11 @@ from functools import partial
import scipy.special import scipy.special
import scipy.stats import scipy.stats
from pytensor import function from pytensor import function, grad
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor import inplace from pytensor.tensor import gammaincc, inplace, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.utils import ( from tests.tensor.utils import (
_good_broadcast_unary_chi2sf, _good_broadcast_unary_chi2sf,
...@@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values(): ...@@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values():
gammaincc_ddk = at.grad(gammainc_out, k) gammaincc_ddk = at.grad(gammainc_out, k)
f_grad = function([k, x], gammaincc_ddk) f_grad = function([k, x], gammaincc_ddk)
rtol = 1e-5 if config.floatX == "float64" else 1e-2
atol = 1e-10 if config.floatX == "float64" else 1e-6
for test_k, test_x, expected_ddk in ( for test_k, test_x, expected_ddk in (
(0.0001, 0, 0), # Limit condition (0.0001, 0, 0), # Limit condition
(0.0001, 0.0001, -8.62594024578651), (0.0001, 0.0001, -8.62594024578651),
...@@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values(): ...@@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values():
(19.0001, 29.7501, -0.007828749832965796), (19.0001, 29.7501, -0.007828749832965796),
): ):
np.testing.assert_allclose( np.testing.assert_allclose(
f_grad(test_k, test_x), expected_ddk, rtol=1e-5, atol=1e-14 f_grad(test_k, test_x), expected_ddk, rtol=rtol, atol=atol
) )
def test_gammaincc_ddk_performance(benchmark):
rng = np.random.default_rng(1)
k = vector("k")
x = vector("x")
out = gammaincc(k, x)
grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
vals = [
# Values that hit the second branch of the gradient
np.full((1000,), 3.2),
np.full((1000,), 0.01),
]
verify_grad(gammaincc, vals, rng=rng)
benchmark(grad_fn, *vals)
TestGammaUBroadcast = makeBroadcastTester( TestGammaUBroadcast = makeBroadcastTester(
op=at.gammau, op=at.gammau,
expected=expected_gammau, expected=expected_gammau,
...@@ -796,7 +818,7 @@ class TestBetaIncGrad: ...@@ -796,7 +818,7 @@ class TestBetaIncGrad:
betainc_out = at.betainc(a, b, z) betainc_out = at.betainc(a, b, z)
betainc_grad = at.grad(betainc_out, [a, b]) betainc_grad = at.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad) f_grad = function([a, b, z], betainc_grad)
decimal = 7 if config.floatX == "float64" else 5
for test_a, test_b, test_z, expected_dda, expected_ddb in ( for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04), (1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
(1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04), (1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
...@@ -806,6 +828,7 @@ class TestBetaIncGrad: ...@@ -806,6 +828,7 @@ class TestBetaIncGrad:
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
f_grad(test_a, test_b, test_z), f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb], [expected_dda, expected_ddb],
decimal=decimal,
) )
def test_beta_inc_stan_grad_combined(self): def test_beta_inc_stan_grad_combined(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论