提交 86fe383d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use ScalarLoop for hyp2f1 gradient

上级 39d37df6
......@@ -5,7 +5,6 @@ As SciPy is not always available, we treat them separately.
"""
import os
import warnings
from textwrap import dedent
import numpy as np
......@@ -26,7 +25,9 @@ from pytensor.scalar.basic import (
expm1,
float64,
float_types,
floor,
identity,
integer_types,
isinf,
log,
log1p,
......@@ -853,15 +854,13 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
s_sign = -s_sign
# log will cast >int16 to float64
log_s_inc = log_x - log(n)
if log_s_inc.type.dtype != log_s.type.dtype:
log_s_inc = log_s_inc.astype(log_s.type.dtype)
log_s += log_s_inc
log_s += log_x - log(n)
if log_s.type.dtype != dtype:
log_s = log_s.astype(dtype)
new_log_delta = log_s - 2 * log(n + k)
if new_log_delta.type.dtype != log_delta.type.dtype:
new_log_delta = new_log_delta.astype(log_delta.type.dtype)
log_delta = new_log_delta
log_delta = log_s - 2 * log(n + k)
if log_delta.type.dtype != dtype:
log_delta = log_delta.astype(dtype)
n += 1
return (
......@@ -1581,9 +1580,9 @@ class Hyp2F1(ScalarOp):
a, b, c, z = inputs
(gz,) = grads
return [
gz * hyp2f1_der(a, b, c, z, wrt=0),
gz * hyp2f1_der(a, b, c, z, wrt=1),
gz * hyp2f1_der(a, b, c, z, wrt=2),
gz * hyp2f1_grad(a, b, c, z, wrt=0),
gz * hyp2f1_grad(a, b, c, z, wrt=1),
gz * hyp2f1_grad(a, b, c, z, wrt=2),
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
]
......@@ -1594,134 +1593,165 @@ class Hyp2F1(ScalarOp):
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
class Hyp2F1Der(ScalarOp):
"""
Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
def _unsafe_sign(x):
# Unlike scalar.sign we don't worry about x being 0 or nan
return switch(x > 0, 1, -1)
Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
"""
nin = 5
def hyp2f1_grad(a, b, c, z, wrt: int):
dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
def impl(self, a, b, c, z, wrt):
def check_2f1_converges(a, b, c, z) -> bool:
num_terms = 0
is_polynomial = False
def check_2f1_converges(a, b, c, z):
def is_nonpositive_integer(x):
if x.type.dtype not in integer_types:
return eq(floor(x), x) & (x <= 0)
else:
return x <= 0
def is_nonpositive_integer(x):
return x <= 0 and x.is_integer()
a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
num_terms = switch(
a_is_polynomial,
floor(scalar_abs(a)).astype("int64"),
0,
)
if is_nonpositive_integer(a) and abs(a) >= num_terms:
is_polynomial = True
num_terms = int(np.floor(abs(a)))
if is_nonpositive_integer(b) and abs(b) >= num_terms:
is_polynomial = True
num_terms = int(np.floor(abs(b)))
b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
num_terms = switch(
b_is_polynomial,
floor(scalar_abs(b)).astype("int64"),
num_terms,
)
is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms
is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
is_polynomial = a_is_polynomial | b_is_polynomial
return not is_undefined and (
is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
)
return (~is_undefined) & (
is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
)
def compute_grad_2f1(a, b, c, z, wrt):
"""
Notes
-----
The algorithm can be derived by looking at the ratio of two successive terms in the series
β_{k+1}/β_{k} = A(k)/B(k)
β_{k+1} = A(k)/B(k) * β_{k}
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
by dropping the respective term
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
tracking their signs.
"""
def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
"""
Notes
-----
The algorithm can be derived by looking at the ratio of two successive terms in the series
β_{k+1}/β_{k} = A(k)/B(k)
β_{k+1} = A(k)/B(k) * β_{k}
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
by dropping the respective term
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
tracking their signs.
"""
wrt_a = wrt_b = False
if wrt == 0:
wrt_a = True
elif wrt == 1:
wrt_b = True
elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
min_steps = np.array(
10, dtype="int32"
) # https://github.com/stan-dev/math/issues/2857
max_steps = switch(
skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
)
precision = np.array(1e-14, dtype=config.floatX)
wrt_a = wrt_b = False
if wrt == 0:
wrt_a = True
elif wrt == 1:
wrt_b = True
elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
min_steps = 10 # https://github.com/stan-dev/math/issues/2857
max_steps = int(1e6)
precision = 1e-14
res = 0
if z == 0:
return res
log_g_old = -np.inf
log_t_old = 0.0
log_t_new = 0.0
sign_z = np.sign(z)
log_z = np.log(np.abs(z))
log_g_old_sign = 1
log_t_old_sign = 1
log_t_new_sign = 1
sign_zk = sign_z
for k in range(max_steps):
p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p == 0:
return res
log_t_new += np.log(np.abs(p)) + log_z
log_t_new_sign = np.sign(p) * log_t_new_sign
term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
if wrt_a:
term += np.reciprocal(a + k)
elif wrt_b:
term += np.reciprocal(b + k)
else:
term -= np.reciprocal(c + k)
log_g_old = log_t_new + np.log(np.abs(term))
log_g_old_sign = np.sign(term) * log_t_new_sign
g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
res += g_current
log_t_old = log_t_new
log_t_old_sign = log_t_new_sign
sign_zk *= sign_z
if k >= min_steps and np.abs(g_current) <= precision:
return res
warnings.warn(
f"hyp2f1_der did not converge after {k} iterations",
RuntimeWarning,
)
return np.nan
grad = np.array(0, dtype=dtype)
log_g = np.array(-np.inf, dtype=dtype)
log_g_sign = np.array(1, dtype="int8")
log_t = np.array(0.0, dtype=dtype)
log_t_sign = np.array(1, dtype="int8")
log_z = log(scalar_abs(z))
sign_z = _unsafe_sign(z)
sign_zk = sign_z
k = np.array(0, dtype="int32")
def inner_loop(
grad,
log_g,
log_g_sign,
log_t,
log_t_sign,
sign_zk,
k,
a,
b,
c,
log_z,
sign_z,
):
p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p.type.dtype != dtype:
p = p.astype(dtype)
term = log_g_sign * log_t_sign * exp(log_g - log_t)
if wrt_a:
term += reciprocal(a + k)
elif wrt_b:
term += reciprocal(b + k)
else:
term -= reciprocal(c + k)
if term.type.dtype != dtype:
term = term.astype(dtype)
log_t = log_t + log(scalar_abs(p)) + log_z
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
log_g = log_t + log(scalar_abs(term))
log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
g_current = log_g_sign * exp(log_g) * sign_zk
# TODO: We could implement the Euler transform to expand supported domain, as Stan does
if not check_2f1_converges(a, b, c, z):
warnings.warn(
f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
RuntimeWarning,
# If p==0, don't update grad and get out of while loop next
grad = switch(
eq(p, 0),
grad,
grad + g_current,
)
return np.nan
return compute_grad_2f1(a, b, c, z, wrt=wrt)
sign_zk *= sign_z
k += 1
def __call__(self, a, b, c, z, wrt, **kwargs):
# This allows wrt to be a keyword argument
return super().__call__(a, b, c, z, wrt, **kwargs)
return (
(grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
(eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))),
)
def c_code(self, *args, **kwargs):
raise NotImplementedError()
init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
constant = [a, b, c, log_z, sign_z]
grad = _make_scalar_loop(
max_steps, init, constant, inner_loop, name="hyp2f1_grad"
)
return switch(
eq(z, 0),
0,
grad,
)
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy
z_is_zero = eq(z, 0)
converges = check_2f1_converges(a, b, c, z)
return switch(
z_is_zero,
0,
switch(
converges,
compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
np.nan,
),
)
from contextlib import ExitStack as does_not_warn
import warnings
import numpy as np
import pytest
......@@ -872,162 +872,183 @@ TestHyp2F1InplaceBroadcast = makeBroadcastTester(
)
def test_hyp2f1_grad_stan_cases():
"""This test reuses the same test cases as in:
https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp
https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp
Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests
"""
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
betainc_out = at.hyp2f1(a1, a2, b1, z)
betainc_grad = at.grad(betainc_out, [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], betainc_grad)
rtol = 1e-9 if config.floatX == "float64" else 1e-3
for (
test_a1,
test_a2,
test_b1,
test_z,
expected_dda1,
expected_dda2,
expected_ddb1,
expected_ddz,
) in (
(
3.70975,
1.0,
2.70975,
-0.2,
-0.0488658806159776,
-0.193844936204681,
0.0677809985598383,
0.8652952472723672,
),
(3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313),
(
1.0,
1.0,
1.0,
0.6,
2.290726829685388,
2.290726829685388,
-2.290726829685388,
6.25,
),
(
1.0,
31.0,
41.0,
1.0,
6.825270649241036,
0.4938271604938271,
-0.382716049382716,
17.22222222222223,
),
(
1.0,
-2.1,
41.0,
1.0,
-0.04921317604093563,
0.02256814168279349,
0.00118482743834665,
-0.04854621426218426,
),
(
1.0,
-0.5,
10.6,
0.3,
-0.01443822031245647,
0.02829710651967078,
0.00136986255602642,
-0.04846036062115473,
),
(
1.0,
-0.5,
10.0,
0.3,
-0.0153218866216130,
0.02999436412836072,
0.0015413242328729,
-0.05144686244336445,
),
(
-0.5,
-4.5,
11.0,
0.3,
-0.1227022810085707,
-0.01298849638043795,
-0.0053540982315572,
0.1959735211840362,
),
(
-0.5,
-4.5,
-3.2,
0.9,
0.85880025358111,
0.4677704416159314,
-4.19010422485256,
-2.959196647856408,
),
(
3.70975,
1.0,
2.70975,
-0.2,
-0.0488658806159776,
-0.193844936204681,
0.0677809985598383,
0.865295247272367,
),
(
2.0,
1.0,
2.0,
0.4,
0.4617734323582945,
0.851376039609984,
-0.4617734323582945,
2.777777777777778,
),
(
3.70975,
1.0,
2.70975,
0.999696,
29369830.002773938200417693317785,
36347869.41885337,
-30843032.10697079073015067426929807,
26278034019.28811,
),
# Cases where series does not converge
(1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf),
(1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf),
# Case where series converges under Euler transform (not implemented!)
# (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889),
(1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889),
):
expectation = (
pytest.warns(
RuntimeWarning, match="Hyp2F1 does not meet convergence conditions"
)
if np.any(
np.isnan([expected_dda1, expected_dda2, expected_ddb1, expected_ddz])
class TestHyp2F1Grad:
few_iters_case = (
2.0,
1.0,
2.0,
0.4,
0.4617734323582945,
0.851376039609984,
-0.4617734323582945,
2.777777777777778,
)
many_iters_case = (
3.70975,
1.0,
2.70975,
0.999696,
29369830.002773938200417693317785,
36347869.41885337,
-30843032.10697079073015067426929807,
26278034019.28811,
)
def test_hyp2f1_grad_stan_cases(self):
"""This test reuses the same test cases as in:
https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp
https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp
Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests
"""
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
hyp2f1_grad = at.grad(hyp2f1_out, [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], hyp2f1_grad)
rtol = 1e-9 if config.floatX == "float64" else 2e-3
for (
test_a1,
test_a2,
test_b1,
test_z,
expected_dda1,
expected_dda2,
expected_ddb1,
expected_ddz,
) in (
(
3.70975,
1.0,
2.70975,
-0.2,
-0.0488658806159776,
-0.193844936204681,
0.0677809985598383,
0.8652952472723672,
),
(3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313),
(
1.0,
1.0,
1.0,
0.6,
2.290726829685388,
2.290726829685388,
-2.290726829685388,
6.25,
),
(
1.0,
31.0,
41.0,
1.0,
6.825270649241036,
0.4938271604938271,
-0.382716049382716,
17.22222222222223,
),
(
1.0,
-2.1,
41.0,
1.0,
-0.04921317604093563,
0.02256814168279349,
0.00118482743834665,
-0.04854621426218426,
),
(
1.0,
-0.5,
10.6,
0.3,
-0.01443822031245647,
0.02829710651967078,
0.00136986255602642,
-0.04846036062115473,
),
(
1.0,
-0.5,
10.0,
0.3,
-0.0153218866216130,
0.02999436412836072,
0.0015413242328729,
-0.05144686244336445,
),
(
-0.5,
-4.5,
11.0,
0.3,
-0.1227022810085707,
-0.01298849638043795,
-0.0053540982315572,
0.1959735211840362,
),
(
-0.5,
-4.5,
-3.2,
0.9,
0.85880025358111,
0.4677704416159314,
-4.19010422485256,
-2.959196647856408,
),
(
3.70975,
1.0,
2.70975,
-0.2,
-0.0488658806159776,
-0.193844936204681,
0.0677809985598383,
0.865295247272367,
),
self.few_iters_case,
self.many_iters_case,
# Cases where series does not converge
(1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf),
(1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf),
# Case where series converges under Euler transform (not implemented!)
# (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889),
(1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889),
):
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.filterwarnings(
"ignore",
category=RuntimeWarning,
message="divide by zero encountered in log",
)
result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))
np.testing.assert_allclose(
result,
np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]),
rtol=rtol,
)
else does_not_warn()
)
with expectation:
result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))
@pytest.mark.parametrize("case", (few_iters_case, many_iters_case))
@pytest.mark.parametrize("wrt", ("a", "all"))
def test_benchmark(self, case, wrt, benchmark):
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
hyp2f1_grad = at.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], hyp2f1_grad)
(test_a1, test_a2, test_b1, test_z, *expected_dds) = case
result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z)
rtol = 1e-9 if config.floatX == "float64" else 2e-3
expected_result = expected_dds[0] if wrt == "a" else np.array(expected_dds)
np.testing.assert_allclose(
result,
np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]),
expected_result,
rtol=rtol,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论