Commit 9b2cb97e authored by Ricardo Vieira, committed by Thomas Wiecki

Implement Hyp2F1 and gradients

Parent 005a3a02
......@@ -1481,3 +1481,169 @@ class BetaIncDer(ScalarOp):
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
class Hyp2F1(ScalarOp):
    """Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    Evaluated numerically through ``scipy.special.hyp2f1``. Gradients with
    respect to the first three parameters are delegated to the series-based
    `hyp2f1_der` op, while the gradient with respect to ``z`` uses the
    closed-form identity ``d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z)``.
    """

    nin = 4
    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)

    @staticmethod
    def st_impl(a, b, c, z):
        # Plain SciPy evaluation; shared by impl so it can be called statically.
        return scipy.special.hyp2f1(a, b, c, z)

    def impl(self, a, b, c, z):
        return Hyp2F1.st_impl(a, b, c, z)

    def grad(self, inputs, grads):
        a, b, c, z = inputs
        (gz,) = grads
        # No simple closed form exists for d/da, d/db, d/dc; use the
        # dedicated derivative op for each of those inputs in turn.
        param_grads = [gz * hyp2f1_der(a, b, c, z, wrt=i) for i in range(3)]
        z_grad = gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)
        return param_grads + [z_grad]

    def c_code(self, *args, **kwargs):
        # Only the Python/SciPy implementation is provided.
        raise NotImplementedError()


hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
class Hyp2F1Der(ScalarOp):
    """
    Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.

    The integer input ``wrt`` selects the derivative: 0 -> a, 1 -> b, 2 -> c.

    Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
    """
    nin = 5

    def impl(self, a, b, c, z, wrt):
        def check_2f1_converges(a, b, c, z) -> bool:
            # Decide whether the 2F1 series is defined and convergent for
            # these arguments (polynomial truncation, |z| < 1, or |z| == 1
            # with c > a + b).
            num_terms = 0
            is_polynomial = False

            def is_nonpositive_integer(x):
                # NOTE(review): relies on x having .is_integer(); inputs are
                # presumably upgraded to float by the Op — confirm for ints.
                return x <= 0 and x.is_integer()

            # A non-positive integer a or b truncates the series to a polynomial.
            if is_nonpositive_integer(a) and abs(a) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(a)))
            if is_nonpositive_integer(b) and abs(b) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(b)))

            # A non-positive integer c reached at or before the truncation
            # point makes a kept term divide by zero: undefined.
            is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms

            return not is_undefined and (
                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
            )

        def compute_grad_2f1(a, b, c, z, wrt):
            """
            Notes
            -----
            The algorithm can be derived by looking at the ratio of two successive terms in the series
            β_{k+1}/β_{k} = A(k)/B(k)
            β_{k+1} = A(k)/B(k) * β_{k}
            d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

            In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k)) / ((c + k) * (1 + k))) * z

            The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
            by dropping the respective term
            d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
            d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
            d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)

            The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
            tracking their signs.
            """
            wrt_a = wrt_b = False
            if wrt == 0:
                wrt_a = True
            elif wrt == 1:
                wrt_b = True
            elif wrt != 2:
                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")

            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
            max_steps = int(1e6)
            precision = 1e-14
            res = 0

            if z == 0:
                return res

            # Log-magnitude / sign decomposition of the series term t_k and
            # of the accumulated derivative term g_k, so large intermediate
            # magnitudes do not overflow before cancellation.
            log_g_old = -np.inf
            log_t_old = 0.0
            log_t_new = 0.0
            sign_z = np.sign(z)
            log_z = np.log(np.abs(z))

            log_g_old_sign = 1
            log_t_old_sign = 1
            log_t_new_sign = 1
            sign_zk = sign_z  # sign of z**(k+1)

            for k in range(max_steps):
                # p = A(k)/B(k) without the z factor (z enters via log_z).
                p = (a + k) * (b + k) / ((c + k) * (k + 1))
                if p == 0:
                    # Series terminated (polynomial case): result is exact.
                    return res
                log_t_new += np.log(np.abs(p)) + log_z
                log_t_new_sign = np.sign(p) * log_t_new_sign

                # Product-rule update expressed relative to the new term:
                # previous g_k / t_k, plus/minus the partial of the ratio.
                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
                if wrt_a:
                    term += np.reciprocal(a + k)
                elif wrt_b:
                    term += np.reciprocal(b + k)
                else:
                    # wrt == 2: d/dc contributes with a negative sign.
                    term -= np.reciprocal(c + k)

                log_g_old = log_t_new + np.log(np.abs(term))
                log_g_old_sign = np.sign(term) * log_t_new_sign
                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
                res += g_current

                log_t_old = log_t_new
                log_t_old_sign = log_t_new_sign
                sign_zk *= sign_z

                # Stop only after min_steps, to avoid spurious early
                # convergence on tiny leading terms.
                if k >= min_steps and np.abs(g_current) <= precision:
                    return res

            warnings.warn(
                f"hyp2f1_der did not converge after {k} iterations",
                RuntimeWarning,
            )
            return np.nan

        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
        if not check_2f1_converges(a, b, c, z):
            warnings.warn(
                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
                RuntimeWarning,
            )
            return np.nan

        return compute_grad_2f1(a, b, c, z, wrt=wrt)

    def __call__(self, a, b, c, z, wrt):
        # This allows wrt to be a keyword argument
        return super().__call__(a, b, c, z, wrt)

    def c_code(self, *args, **kwargs):
        # Only the Python implementation is provided.
        raise NotImplementedError()


hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
......@@ -392,6 +392,11 @@ def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)"""
@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
    """gaussian hypergeometric function (inplace on `a`)"""
# Pretty-printer registrations: render the inplace arithmetic ops with
# their augmented-assignment operator symbols.
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))
......
......@@ -1384,6 +1384,11 @@ def gammal(k, x):
"""Lower incomplete gamma function."""
@scalar_elemwise
def hyp2f1(a, b, c, z):
    """Gaussian hypergeometric function ``2F1(a, b; c; z)``."""
@scalar_elemwise
def j0(x):
    """Bessel function of the first kind of order 0, ``J0(x)``."""
......@@ -3132,4 +3137,5 @@ __all__ = [
"power",
"logaddexp",
"logsumexp",
"hyp2f1",
]
from contextlib import ExitStack as does_not_warn
import numpy as np
import pytest
......@@ -71,6 +73,7 @@ expected_i1 = scipy.special.i1
expected_iv = scipy.special.iv
expected_erfcx = scipy.special.erfcx
expected_sigmoid = scipy.special.expit
# Reference implementation for the hyp2f1 broadcast/grad tests below.
expected_hyp2f1 = scipy.special.hyp2f1
TestErfBroadcast = makeBroadcastTester(
op=at.erf,
......@@ -820,3 +823,189 @@ class TestBetaIncGrad:
np.testing.assert_allclose(
f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
)
# Inputs for the quaternary hyp2f1 broadcast testers: a, b, c drawn from
# [0, 20) and z drawn strictly inside (-0.9, 0.9), within the unit disk.
_good_broadcast_quaternary_hyp2f1 = {
    "normal": (
        random_ranged(0, 20, (2, 3)),
        random_ranged(0, 20, (2, 3)),
        random_ranged(0, 20, (2, 3)),
        random_ranged(-0.9, 0.9, (2, 3)),
    ),
}
# Broadcast/grad tester for the out-of-place hyp2f1 op, checked against SciPy.
TestHyp2F1Broadcast = makeBroadcastTester(
    op=at.hyp2f1,
    expected=expected_hyp2f1,
    good=_good_broadcast_quaternary_hyp2f1,
    grad=_good_broadcast_quaternary_hyp2f1,
)

# Same cases applied to the inplace variant of the op (no grad check).
TestHyp2F1InplaceBroadcast = makeBroadcastTester(
    op=inplace.hyp2f1_inplace,
    expected=expected_hyp2f1,
    good=_good_broadcast_quaternary_hyp2f1,
    inplace=True,
)
def test_hyp2f1_grad_stan_cases():
    """Check hyp2f1 gradients against Stan's reference values.

    This test reuses the same test cases as in:
    https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp
    https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp

    Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests
    """
    a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
    # Fixed: locals were previously named `betainc_*`, a copy-paste leftover
    # from the BetaInc tests; they compute hyp2f1 and its gradients.
    hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
    hyp2f1_grad = at.grad(hyp2f1_out, [a1, a2, b1, z])
    f_grad = function([a1, a2, b1, z], hyp2f1_grad)

    rtol = 1e-9 if config.floatX == "float64" else 1e-3
    for (
        test_a1,
        test_a2,
        test_b1,
        test_z,
        expected_dda1,
        expected_dda2,
        expected_ddb1,
        expected_ddz,
    ) in (
        (
            3.70975,
            1.0,
            2.70975,
            -0.2,
            -0.0488658806159776,
            -0.193844936204681,
            0.0677809985598383,
            0.8652952472723672,
        ),
        (3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313),
        (
            1.0,
            1.0,
            1.0,
            0.6,
            2.290726829685388,
            2.290726829685388,
            -2.290726829685388,
            6.25,
        ),
        (
            1.0,
            31.0,
            41.0,
            1.0,
            6.825270649241036,
            0.4938271604938271,
            -0.382716049382716,
            17.22222222222223,
        ),
        (
            1.0,
            -2.1,
            41.0,
            1.0,
            -0.04921317604093563,
            0.02256814168279349,
            0.00118482743834665,
            -0.04854621426218426,
        ),
        (
            1.0,
            -0.5,
            10.6,
            0.3,
            -0.01443822031245647,
            0.02829710651967078,
            0.00136986255602642,
            -0.04846036062115473,
        ),
        (
            1.0,
            -0.5,
            10.0,
            0.3,
            -0.0153218866216130,
            0.02999436412836072,
            0.0015413242328729,
            -0.05144686244336445,
        ),
        (
            -0.5,
            -4.5,
            11.0,
            0.3,
            -0.1227022810085707,
            -0.01298849638043795,
            -0.0053540982315572,
            0.1959735211840362,
        ),
        (
            -0.5,
            -4.5,
            -3.2,
            0.9,
            0.85880025358111,
            0.4677704416159314,
            -4.19010422485256,
            -2.959196647856408,
        ),
        (
            3.70975,
            1.0,
            2.70975,
            -0.2,
            -0.0488658806159776,
            -0.193844936204681,
            0.0677809985598383,
            0.865295247272367,
        ),
        (
            2.0,
            1.0,
            2.0,
            0.4,
            0.4617734323582945,
            0.851376039609984,
            -0.4617734323582945,
            2.777777777777778,
        ),
        (
            3.70975,
            1.0,
            2.70975,
            0.999696,
            29369830.002773938200417693317785,
            36347869.41885337,
            -30843032.10697079073015067426929807,
            26278034019.28811,
        ),
        # Cases where series does not converge
        (1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf),
        (1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf),
        # Case where series converges under Euler transform (not implemented!)
        # (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889),
        (1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889),
    ):
        # Cases whose expected parameter-gradients are nan fail the series
        # convergence check and must emit Hyp2F1Der's RuntimeWarning.
        expectation = (
            pytest.warns(
                RuntimeWarning, match="Hyp2F1 does not meet convergence conditions"
            )
            if np.any(
                np.isnan([expected_dda1, expected_dda2, expected_ddb1, expected_ddz])
            )
            else does_not_warn()
        )
        with expectation:
            result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))

        np.testing.assert_allclose(
            result,
            np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]),
            rtol=rtol,
        )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论