提交 86fe383d · 作者: Ricardo Vieira · 提交者: Ricardo Vieira

Use ScalarLoop for hyp2f1 gradient

上级 39d37df6
......@@ -5,7 +5,6 @@ As SciPy is not always available, we treat them separately.
"""
import os
import warnings
from textwrap import dedent
import numpy as np
......@@ -26,7 +25,9 @@ from pytensor.scalar.basic import (
expm1,
float64,
float_types,
floor,
identity,
integer_types,
isinf,
log,
log1p,
......@@ -853,15 +854,13 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
s_sign = -s_sign
# log will cast >int16 to float64
log_s_inc = log_x - log(n)
if log_s_inc.type.dtype != log_s.type.dtype:
log_s_inc = log_s_inc.astype(log_s.type.dtype)
log_s += log_s_inc
log_s += log_x - log(n)
if log_s.type.dtype != dtype:
log_s = log_s.astype(dtype)
new_log_delta = log_s - 2 * log(n + k)
if new_log_delta.type.dtype != log_delta.type.dtype:
new_log_delta = new_log_delta.astype(log_delta.type.dtype)
log_delta = new_log_delta
log_delta = log_s - 2 * log(n + k)
if log_delta.type.dtype != dtype:
log_delta = log_delta.astype(dtype)
n += 1
return (
......@@ -1581,9 +1580,9 @@ class Hyp2F1(ScalarOp):
a, b, c, z = inputs
(gz,) = grads
return [
gz * hyp2f1_der(a, b, c, z, wrt=0),
gz * hyp2f1_der(a, b, c, z, wrt=1),
gz * hyp2f1_der(a, b, c, z, wrt=2),
gz * hyp2f1_grad(a, b, c, z, wrt=0),
gz * hyp2f1_grad(a, b, c, z, wrt=1),
gz * hyp2f1_grad(a, b, c, z, wrt=2),
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
]
......@@ -1594,37 +1593,43 @@ class Hyp2F1(ScalarOp):
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
class Hyp2F1Der(ScalarOp):
"""
Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
"""
def _unsafe_sign(x):
    """Return ``1`` where ``x > 0`` and ``-1`` everywhere else.

    This is a cheaper stand-in for ``scalar.sign``: it makes no attempt to
    treat ``x == 0`` or ``x`` being nan specially, so any input for which
    ``x > 0`` does not hold is mapped to ``-1``.
    """
    is_positive = x > 0
    return switch(is_positive, 1, -1)
nin = 5
def impl(self, a, b, c, z, wrt):
def check_2f1_converges(a, b, c, z) -> bool:
num_terms = 0
is_polynomial = False
def hyp2f1_grad(a, b, c, z, wrt: int):
dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
def check_2f1_converges(a, b, c, z):
def is_nonpositive_integer(x):
return x <= 0 and x.is_integer()
if x.type.dtype not in integer_types:
return eq(floor(x), x) & (x <= 0)
else:
return x <= 0
if is_nonpositive_integer(a) and abs(a) >= num_terms:
is_polynomial = True
num_terms = int(np.floor(abs(a)))
if is_nonpositive_integer(b) and abs(b) >= num_terms:
is_polynomial = True
num_terms = int(np.floor(abs(b)))
a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
num_terms = switch(
a_is_polynomial,
floor(scalar_abs(a)).astype("int64"),
0,
)
b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
num_terms = switch(
b_is_polynomial,
floor(scalar_abs(b)).astype("int64"),
num_terms,
)
is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms
is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
is_polynomial = a_is_polynomial | b_is_polynomial
return not is_undefined and (
is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
return (~is_undefined) & (
is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
)
def compute_grad_2f1(a, b, c, z, wrt):
def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
"""
Notes
-----
......@@ -1653,75 +1658,100 @@ class Hyp2F1Der(ScalarOp):
elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
min_steps = 10 # https://github.com/stan-dev/math/issues/2857
max_steps = int(1e6)
precision = 1e-14
min_steps = np.array(
10, dtype="int32"
) # https://github.com/stan-dev/math/issues/2857
max_steps = switch(
skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
)
precision = np.array(1e-14, dtype=config.floatX)
grad = np.array(0, dtype=dtype)
res = 0
log_g = np.array(-np.inf, dtype=dtype)
log_g_sign = np.array(1, dtype="int8")
if z == 0:
return res
log_t = np.array(0.0, dtype=dtype)
log_t_sign = np.array(1, dtype="int8")
log_g_old = -np.inf
log_t_old = 0.0
log_t_new = 0.0
sign_z = np.sign(z)
log_z = np.log(np.abs(z))
log_z = log(scalar_abs(z))
sign_z = _unsafe_sign(z)
log_g_old_sign = 1
log_t_old_sign = 1
log_t_new_sign = 1
sign_zk = sign_z
k = np.array(0, dtype="int32")
for k in range(max_steps):
def inner_loop(
grad,
log_g,
log_g_sign,
log_t,
log_t_sign,
sign_zk,
k,
a,
b,
c,
log_z,
sign_z,
):
p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p == 0:
return res
log_t_new += np.log(np.abs(p)) + log_z
log_t_new_sign = np.sign(p) * log_t_new_sign
if p.type.dtype != dtype:
p = p.astype(dtype)
term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
term = log_g_sign * log_t_sign * exp(log_g - log_t)
if wrt_a:
term += np.reciprocal(a + k)
term += reciprocal(a + k)
elif wrt_b:
term += np.reciprocal(b + k)
term += reciprocal(b + k)
else:
term -= np.reciprocal(c + k)
term -= reciprocal(c + k)
log_g_old = log_t_new + np.log(np.abs(term))
log_g_old_sign = np.sign(term) * log_t_new_sign
g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
res += g_current
if term.type.dtype != dtype:
term = term.astype(dtype)
log_t_old = log_t_new
log_t_old_sign = log_t_new_sign
sign_zk *= sign_z
if k >= min_steps and np.abs(g_current) <= precision:
return res
log_t = log_t + log(scalar_abs(p)) + log_z
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
log_g = log_t + log(scalar_abs(term))
log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
warnings.warn(
f"hyp2f1_der did not converge after {k} iterations",
RuntimeWarning,
)
return np.nan
g_current = log_g_sign * exp(log_g) * sign_zk
# TODO: We could implement the Euler transform to expand supported domain, as Stan does
if not check_2f1_converges(a, b, c, z):
warnings.warn(
f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
RuntimeWarning,
# If p==0, don't update grad and get out of while loop next
grad = switch(
eq(p, 0),
grad,
grad + g_current,
)
return np.nan
return compute_grad_2f1(a, b, c, z, wrt=wrt)
sign_zk *= sign_z
k += 1
def __call__(self, a, b, c, z, wrt, **kwargs):
# This allows wrt to be a keyword argument
return super().__call__(a, b, c, z, wrt, **kwargs)
return (
(grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
(eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))),
)
def c_code(self, *args, **kwargs):
raise NotImplementedError()
init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
constant = [a, b, c, log_z, sign_z]
grad = _make_scalar_loop(
max_steps, init, constant, inner_loop, name="hyp2f1_grad"
)
return switch(
eq(z, 0),
0,
grad,
)
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy
z_is_zero = eq(z, 0)
converges = check_2f1_converges(a, b, c, z)
return switch(
z_is_zero,
0,
switch(
converges,
compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
np.nan,
),
)
from contextlib import ExitStack as does_not_warn
import warnings
import numpy as np
import pytest
......@@ -872,7 +872,30 @@ TestHyp2F1InplaceBroadcast = makeBroadcastTester(
)
def test_hyp2f1_grad_stan_cases():
class TestHyp2F1Grad:
# Reference cases shared by the gradient tests of this class.
# Each tuple is laid out as
#   (a1, a2, b1, z, expected_dda1, expected_dda2, expected_ddb1, expected_ddz)
# which is the order in which the tests unpack a case and compare the results.
# NOTE(review): the name suggests this case converges after few series
# iterations; that cannot be verified from the values alone -- confirm.
few_iters_case = (
2.0,  # a1
1.0,  # a2
2.0,  # b1
0.4,  # z
0.4617734323582945,  # expected gradient wrt a1
0.851376039609984,  # expected gradient wrt a2
-0.4617734323582945,  # expected gradient wrt b1
2.777777777777778,  # expected gradient wrt z
)
# Same layout as ``few_iters_case``.
# NOTE(review): presumably needs many series iterations because z is close
# to 1 -- inferred from the name, confirm.
many_iters_case = (
3.70975,  # a1
1.0,  # a2
2.70975,  # b1
0.999696,  # z
29369830.002773938200417693317785,  # expected gradient wrt a1
36347869.41885337,  # expected gradient wrt a2
-30843032.10697079073015067426929807,  # expected gradient wrt b1
26278034019.28811,  # expected gradient wrt z
)
def test_hyp2f1_grad_stan_cases(self):
"""This test reuses the same test cases as in:
https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp
https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp
......@@ -880,12 +903,11 @@ def test_hyp2f1_grad_stan_cases():
Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests
"""
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
betainc_out = at.hyp2f1(a1, a2, b1, z)
betainc_grad = at.grad(betainc_out, [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], betainc_grad)
rtol = 1e-9 if config.floatX == "float64" else 1e-3
hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
hyp2f1_grad = at.grad(hyp2f1_out, [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], hyp2f1_grad)
rtol = 1e-9 if config.floatX == "float64" else 2e-3
for (
test_a1,
test_a2,
......@@ -987,26 +1009,8 @@ def test_hyp2f1_grad_stan_cases():
0.0677809985598383,
0.865295247272367,
),
(
2.0,
1.0,
2.0,
0.4,
0.4617734323582945,
0.851376039609984,
-0.4617734323582945,
2.777777777777778,
),
(
3.70975,
1.0,
2.70975,
0.999696,
29369830.002773938200417693317785,
36347869.41885337,
-30843032.10697079073015067426929807,
26278034019.28811,
),
self.few_iters_case,
self.many_iters_case,
# Cases where series does not converge
(1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf),
(1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf),
......@@ -1014,16 +1018,13 @@ def test_hyp2f1_grad_stan_cases():
# (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889),
(1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889),
):
expectation = (
pytest.warns(
RuntimeWarning, match="Hyp2F1 does not meet convergence conditions"
)
if np.any(
np.isnan([expected_dda1, expected_dda2, expected_ddb1, expected_ddz])
)
else does_not_warn()
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.filterwarnings(
"ignore",
category=RuntimeWarning,
message="divide by zero encountered in log",
)
with expectation:
result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))
np.testing.assert_allclose(
......@@ -1031,3 +1032,23 @@ def test_hyp2f1_grad_stan_cases():
np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]),
rtol=rtol,
)
@pytest.mark.parametrize("case", (few_iters_case, many_iters_case))
@pytest.mark.parametrize("wrt", ("a", "all"))
def test_benchmark(self, case, wrt, benchmark):
    """Benchmark the compiled ``hyp2f1`` gradient and check its output.

    ``case`` is one of the class-level reference tuples: four input values
    followed by the expected gradients. ``wrt`` selects whether the gradient
    is taken with respect to ``a1`` only (``"a"``) or to all four inputs
    (``"all"``). ``benchmark`` is called with the compiled function and the
    input values, and its return value is what gets compared.
    """
    a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
    hyp2f1_out = at.hyp2f1(a1, a2, b1, z)

    # Differentiate either with respect to the first input or to every input.
    only_first = wrt == "a"
    if only_first:
        grad_wrt = a1
    else:
        grad_wrt = [a1, a2, b1, z]
    hyp2f1_grad = at.grad(hyp2f1_out, wrt=grad_wrt)
    f_grad = function([a1, a2, b1, z], hyp2f1_grad)

    # First four entries are the inputs, the remainder the expected gradients.
    test_a1, test_a2, test_b1, test_z, *expected_dds = case
    result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z)

    # Lower precision is tolerated when not running in float64.
    if config.floatX == "float64":
        rtol = 1e-9
    else:
        rtol = 2e-3

    if only_first:
        expected_result = expected_dds[0]
    else:
        expected_result = np.array(expected_dds)

    np.testing.assert_allclose(result, expected_result, rtol=rtol)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论