提交 86fe383d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use ScalarLoop for hyp2f1 gradient

上级 39d37df6
...@@ -5,7 +5,6 @@ As SciPy is not always available, we treat them separately. ...@@ -5,7 +5,6 @@ As SciPy is not always available, we treat them separately.
""" """
import os import os
import warnings
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
...@@ -26,7 +25,9 @@ from pytensor.scalar.basic import ( ...@@ -26,7 +25,9 @@ from pytensor.scalar.basic import (
expm1, expm1,
float64, float64,
float_types, float_types,
floor,
identity, identity,
integer_types,
isinf, isinf,
log, log,
log1p, log1p,
...@@ -853,15 +854,13 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -853,15 +854,13 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
s_sign = -s_sign s_sign = -s_sign
# log will cast >int16 to float64 # log will cast >int16 to float64
log_s_inc = log_x - log(n) log_s += log_x - log(n)
if log_s_inc.type.dtype != log_s.type.dtype: if log_s.type.dtype != dtype:
log_s_inc = log_s_inc.astype(log_s.type.dtype) log_s = log_s.astype(dtype)
log_s += log_s_inc
new_log_delta = log_s - 2 * log(n + k) log_delta = log_s - 2 * log(n + k)
if new_log_delta.type.dtype != log_delta.type.dtype: if log_delta.type.dtype != dtype:
new_log_delta = new_log_delta.astype(log_delta.type.dtype) log_delta = log_delta.astype(dtype)
log_delta = new_log_delta
n += 1 n += 1
return ( return (
...@@ -1581,9 +1580,9 @@ class Hyp2F1(ScalarOp): ...@@ -1581,9 +1580,9 @@ class Hyp2F1(ScalarOp):
a, b, c, z = inputs a, b, c, z = inputs
(gz,) = grads (gz,) = grads
return [ return [
gz * hyp2f1_der(a, b, c, z, wrt=0), gz * hyp2f1_grad(a, b, c, z, wrt=0),
gz * hyp2f1_der(a, b, c, z, wrt=1), gz * hyp2f1_grad(a, b, c, z, wrt=1),
gz * hyp2f1_der(a, b, c, z, wrt=2), gz * hyp2f1_grad(a, b, c, z, wrt=2),
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z), gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
] ]
...@@ -1594,134 +1593,165 @@ class Hyp2F1(ScalarOp): ...@@ -1594,134 +1593,165 @@ class Hyp2F1(ScalarOp):
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1") hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
class Hyp2F1Der(ScalarOp): def _unsafe_sign(x):
""" # Unlike scalar.sign we don't worry about x being 0 or nan
Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs. return switch(x > 0, 1, -1)
Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
"""
nin = 5 def hyp2f1_grad(a, b, c, z, wrt: int):
dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
def impl(self, a, b, c, z, wrt): def check_2f1_converges(a, b, c, z):
def check_2f1_converges(a, b, c, z) -> bool: def is_nonpositive_integer(x):
num_terms = 0 if x.type.dtype not in integer_types:
is_polynomial = False return eq(floor(x), x) & (x <= 0)
else:
return x <= 0
def is_nonpositive_integer(x): a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
return x <= 0 and x.is_integer() num_terms = switch(
a_is_polynomial,
floor(scalar_abs(a)).astype("int64"),
0,
)
if is_nonpositive_integer(a) and abs(a) >= num_terms: b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
is_polynomial = True num_terms = switch(
num_terms = int(np.floor(abs(a))) b_is_polynomial,
if is_nonpositive_integer(b) and abs(b) >= num_terms: floor(scalar_abs(b)).astype("int64"),
is_polynomial = True num_terms,
num_terms = int(np.floor(abs(b))) )
is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
is_polynomial = a_is_polynomial | b_is_polynomial
return not is_undefined and ( return (~is_undefined) & (
is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b)) is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
) )
def compute_grad_2f1(a, b, c, z, wrt): def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
""" """
Notes Notes
----- -----
The algorithm can be derived by looking at the ratio of two successive terms in the series The algorithm can be derived by looking at the ratio of two successive terms in the series
β_{k+1}/β_{k} = A(k)/B(k) β_{k+1}/β_{k} = A(k)/B(k)
β_{k+1} = A(k)/B(k) * β_{k} β_{k+1} = A(k)/B(k) * β_{k}
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k), The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
by dropping the respective term by dropping the respective term
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k) d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k) d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k) d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
tracking their signs. tracking their signs.
""" """
wrt_a = wrt_b = False
if wrt == 0:
wrt_a = True
elif wrt == 1:
wrt_b = True
elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
min_steps = np.array(
10, dtype="int32"
) # https://github.com/stan-dev/math/issues/2857
max_steps = switch(
skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
)
precision = np.array(1e-14, dtype=config.floatX)
wrt_a = wrt_b = False grad = np.array(0, dtype=dtype)
if wrt == 0:
wrt_a = True log_g = np.array(-np.inf, dtype=dtype)
elif wrt == 1: log_g_sign = np.array(1, dtype="int8")
wrt_b = True
elif wrt != 2: log_t = np.array(0.0, dtype=dtype)
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}") log_t_sign = np.array(1, dtype="int8")
min_steps = 10 # https://github.com/stan-dev/math/issues/2857 log_z = log(scalar_abs(z))
max_steps = int(1e6) sign_z = _unsafe_sign(z)
precision = 1e-14
sign_zk = sign_z
res = 0 k = np.array(0, dtype="int32")
if z == 0: def inner_loop(
return res grad,
log_g,
log_g_old = -np.inf log_g_sign,
log_t_old = 0.0 log_t,
log_t_new = 0.0 log_t_sign,
sign_z = np.sign(z) sign_zk,
log_z = np.log(np.abs(z)) k,
a,
log_g_old_sign = 1 b,
log_t_old_sign = 1 c,
log_t_new_sign = 1 log_z,
sign_zk = sign_z sign_z,
):
for k in range(max_steps): p = (a + k) * (b + k) / ((c + k) * (k + 1))
p = (a + k) * (b + k) / ((c + k) * (k + 1)) if p.type.dtype != dtype:
if p == 0: p = p.astype(dtype)
return res
log_t_new += np.log(np.abs(p)) + log_z term = log_g_sign * log_t_sign * exp(log_g - log_t)
log_t_new_sign = np.sign(p) * log_t_new_sign if wrt_a:
term += reciprocal(a + k)
term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old) elif wrt_b:
if wrt_a: term += reciprocal(b + k)
term += np.reciprocal(a + k) else:
elif wrt_b: term -= reciprocal(c + k)
term += np.reciprocal(b + k)
else: if term.type.dtype != dtype:
term -= np.reciprocal(c + k) term = term.astype(dtype)
log_g_old = log_t_new + np.log(np.abs(term)) log_t = log_t + log(scalar_abs(p)) + log_z
log_g_old_sign = np.sign(term) * log_t_new_sign log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk log_g = log_t + log(scalar_abs(term))
res += g_current log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
log_t_old = log_t_new g_current = log_g_sign * exp(log_g) * sign_zk
log_t_old_sign = log_t_new_sign
sign_zk *= sign_z
if k >= min_steps and np.abs(g_current) <= precision:
return res
warnings.warn(
f"hyp2f1_der did not converge after {k} iterations",
RuntimeWarning,
)
return np.nan
# TODO: We could implement the Euler transform to expand supported domain, as Stan does # If p==0, don't update grad and get out of while loop next
if not check_2f1_converges(a, b, c, z): grad = switch(
warnings.warn( eq(p, 0),
f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}", grad,
RuntimeWarning, grad + g_current,
) )
return np.nan
return compute_grad_2f1(a, b, c, z, wrt=wrt) sign_zk *= sign_z
k += 1
def __call__(self, a, b, c, z, wrt, **kwargs): return (
# This allows wrt to be a keyword argument (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
return super().__call__(a, b, c, z, wrt, **kwargs) (eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))),
)
def c_code(self, *args, **kwargs): init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
raise NotImplementedError() constant = [a, b, c, log_z, sign_z]
grad = _make_scalar_loop(
max_steps, init, constant, inner_loop, name="hyp2f1_grad"
)
return switch(
eq(z, 0),
0,
grad,
)
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der") # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
z_is_zero = eq(z, 0)
converges = check_2f1_converges(a, b, c, z)
return switch(
z_is_zero,
0,
switch(
converges,
compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
np.nan,
),
)
from contextlib import ExitStack as does_not_warn import warnings
import numpy as np import numpy as np
import pytest import pytest
...@@ -872,162 +872,183 @@ TestHyp2F1InplaceBroadcast = makeBroadcastTester( ...@@ -872,162 +872,183 @@ TestHyp2F1InplaceBroadcast = makeBroadcastTester(
) )
def test_hyp2f1_grad_stan_cases(): class TestHyp2F1Grad:
"""This test reuses the same test cases as in: few_iters_case = (
https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp 2.0,
https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp 1.0,
2.0,
Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests 0.4,
""" 0.4617734323582945,
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z") 0.851376039609984,
betainc_out = at.hyp2f1(a1, a2, b1, z) -0.4617734323582945,
betainc_grad = at.grad(betainc_out, [a1, a2, b1, z]) 2.777777777777778,
f_grad = function([a1, a2, b1, z], betainc_grad) )
rtol = 1e-9 if config.floatX == "float64" else 1e-3 many_iters_case = (
3.70975,
for ( 1.0,
test_a1, 2.70975,
test_a2, 0.999696,
test_b1, 29369830.002773938200417693317785,
test_z, 36347869.41885337,
expected_dda1, -30843032.10697079073015067426929807,
expected_dda2, 26278034019.28811,
expected_ddb1, )
expected_ddz,
) in ( def test_hyp2f1_grad_stan_cases(self):
( """This test reuses the same test cases as in:
3.70975, https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp
1.0, https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp
2.70975,
-0.2, Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests
-0.0488658806159776, """
-0.193844936204681, a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
0.0677809985598383, hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
0.8652952472723672, hyp2f1_grad = at.grad(hyp2f1_out, [a1, a2, b1, z])
), f_grad = function([a1, a2, b1, z], hyp2f1_grad)
(3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313),
( rtol = 1e-9 if config.floatX == "float64" else 2e-3
1.0, for (
1.0, test_a1,
1.0, test_a2,
0.6, test_b1,
2.290726829685388, test_z,
2.290726829685388, expected_dda1,
-2.290726829685388, expected_dda2,
6.25, expected_ddb1,
), expected_ddz,
( ) in (
1.0, (
31.0, 3.70975,
41.0, 1.0,
1.0, 2.70975,
6.825270649241036, -0.2,
0.4938271604938271, -0.0488658806159776,
-0.382716049382716, -0.193844936204681,
17.22222222222223, 0.0677809985598383,
), 0.8652952472723672,
( ),
1.0, (3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313),
-2.1, (
41.0, 1.0,
1.0, 1.0,
-0.04921317604093563, 1.0,
0.02256814168279349, 0.6,
0.00118482743834665, 2.290726829685388,
-0.04854621426218426, 2.290726829685388,
), -2.290726829685388,
( 6.25,
1.0, ),
-0.5, (
10.6, 1.0,
0.3, 31.0,
-0.01443822031245647, 41.0,
0.02829710651967078, 1.0,
0.00136986255602642, 6.825270649241036,
-0.04846036062115473, 0.4938271604938271,
), -0.382716049382716,
( 17.22222222222223,
1.0, ),
-0.5, (
10.0, 1.0,
0.3, -2.1,
-0.0153218866216130, 41.0,
0.02999436412836072, 1.0,
0.0015413242328729, -0.04921317604093563,
-0.05144686244336445, 0.02256814168279349,
), 0.00118482743834665,
( -0.04854621426218426,
-0.5, ),
-4.5, (
11.0, 1.0,
0.3, -0.5,
-0.1227022810085707, 10.6,
-0.01298849638043795, 0.3,
-0.0053540982315572, -0.01443822031245647,
0.1959735211840362, 0.02829710651967078,
), 0.00136986255602642,
( -0.04846036062115473,
-0.5, ),
-4.5, (
-3.2, 1.0,
0.9, -0.5,
0.85880025358111, 10.0,
0.4677704416159314, 0.3,
-4.19010422485256, -0.0153218866216130,
-2.959196647856408, 0.02999436412836072,
), 0.0015413242328729,
( -0.05144686244336445,
3.70975, ),
1.0, (
2.70975, -0.5,
-0.2, -4.5,
-0.0488658806159776, 11.0,
-0.193844936204681, 0.3,
0.0677809985598383, -0.1227022810085707,
0.865295247272367, -0.01298849638043795,
), -0.0053540982315572,
( 0.1959735211840362,
2.0, ),
1.0, (
2.0, -0.5,
0.4, -4.5,
0.4617734323582945, -3.2,
0.851376039609984, 0.9,
-0.4617734323582945, 0.85880025358111,
2.777777777777778, 0.4677704416159314,
), -4.19010422485256,
( -2.959196647856408,
3.70975, ),
1.0, (
2.70975, 3.70975,
0.999696, 1.0,
29369830.002773938200417693317785, 2.70975,
36347869.41885337, -0.2,
-30843032.10697079073015067426929807, -0.0488658806159776,
26278034019.28811, -0.193844936204681,
), 0.0677809985598383,
# Cases where series does not converge 0.865295247272367,
(1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf), ),
(1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf), self.few_iters_case,
# Case where series converges under Euler transform (not implemented!) self.many_iters_case,
# (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889), # Cases where series does not converge
(1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889), (1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf),
): (1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf),
expectation = ( # Case where series converges under Euler transform (not implemented!)
pytest.warns( # (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889),
RuntimeWarning, match="Hyp2F1 does not meet convergence conditions" (1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889),
) ):
if np.any( with warnings.catch_warnings():
np.isnan([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]) warnings.simplefilter("error")
warnings.filterwarnings(
"ignore",
category=RuntimeWarning,
message="divide by zero encountered in log",
)
result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))
np.testing.assert_allclose(
result,
np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]),
rtol=rtol,
) )
else does_not_warn()
)
with expectation:
result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))
@pytest.mark.parametrize("case", (few_iters_case, many_iters_case))
@pytest.mark.parametrize("wrt", ("a", "all"))
def test_benchmark(self, case, wrt, benchmark):
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
hyp2f1_grad = at.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], hyp2f1_grad)
(test_a1, test_a2, test_b1, test_z, *expected_dds) = case
result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z)
rtol = 1e-9 if config.floatX == "float64" else 2e-3
expected_result = expected_dds[0] if wrt == "a" else np.array(expected_dds)
np.testing.assert_allclose( np.testing.assert_allclose(
result, result,
np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]), expected_result,
rtol=rtol, rtol=rtol,
) )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论