提交 86fe383d 作者: Ricardo Vieira 提交者: Ricardo Vieira

Use ScalarLoop for hyp2f1 gradient

上级 39d37df6
...@@ -5,7 +5,6 @@ As SciPy is not always available, we treat them separately. ...@@ -5,7 +5,6 @@ As SciPy is not always available, we treat them separately.
""" """
import os import os
import warnings
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
...@@ -26,7 +25,9 @@ from pytensor.scalar.basic import ( ...@@ -26,7 +25,9 @@ from pytensor.scalar.basic import (
expm1, expm1,
float64, float64,
float_types, float_types,
floor,
identity, identity,
integer_types,
isinf, isinf,
log, log,
log1p, log1p,
...@@ -853,15 +854,13 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -853,15 +854,13 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
s_sign = -s_sign s_sign = -s_sign
# log will cast >int16 to float64 # log will cast >int16 to float64
log_s_inc = log_x - log(n) log_s += log_x - log(n)
if log_s_inc.type.dtype != log_s.type.dtype: if log_s.type.dtype != dtype:
log_s_inc = log_s_inc.astype(log_s.type.dtype) log_s = log_s.astype(dtype)
log_s += log_s_inc
new_log_delta = log_s - 2 * log(n + k) log_delta = log_s - 2 * log(n + k)
if new_log_delta.type.dtype != log_delta.type.dtype: if log_delta.type.dtype != dtype:
new_log_delta = new_log_delta.astype(log_delta.type.dtype) log_delta = log_delta.astype(dtype)
log_delta = new_log_delta
n += 1 n += 1
return ( return (
...@@ -1581,9 +1580,9 @@ class Hyp2F1(ScalarOp): ...@@ -1581,9 +1580,9 @@ class Hyp2F1(ScalarOp):
a, b, c, z = inputs a, b, c, z = inputs
(gz,) = grads (gz,) = grads
return [ return [
gz * hyp2f1_der(a, b, c, z, wrt=0), gz * hyp2f1_grad(a, b, c, z, wrt=0),
gz * hyp2f1_der(a, b, c, z, wrt=1), gz * hyp2f1_grad(a, b, c, z, wrt=1),
gz * hyp2f1_der(a, b, c, z, wrt=2), gz * hyp2f1_grad(a, b, c, z, wrt=2),
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z), gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
] ]
...@@ -1594,37 +1593,43 @@ class Hyp2F1(ScalarOp): ...@@ -1594,37 +1593,43 @@ class Hyp2F1(ScalarOp):
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1") hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
class Hyp2F1Der(ScalarOp): def _unsafe_sign(x):
""" # Unlike scalar.sign we don't worry about x being 0 or nan
Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs. return switch(x > 0, 1, -1)
Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
"""
nin = 5
def impl(self, a, b, c, z, wrt): def hyp2f1_grad(a, b, c, z, wrt: int):
def check_2f1_converges(a, b, c, z) -> bool: dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
num_terms = 0
is_polynomial = False
def check_2f1_converges(a, b, c, z):
def is_nonpositive_integer(x): def is_nonpositive_integer(x):
return x <= 0 and x.is_integer() if x.type.dtype not in integer_types:
return eq(floor(x), x) & (x <= 0)
else:
return x <= 0
if is_nonpositive_integer(a) and abs(a) >= num_terms: a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
is_polynomial = True num_terms = switch(
num_terms = int(np.floor(abs(a))) a_is_polynomial,
if is_nonpositive_integer(b) and abs(b) >= num_terms: floor(scalar_abs(a)).astype("int64"),
is_polynomial = True 0,
num_terms = int(np.floor(abs(b))) )
b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
num_terms = switch(
b_is_polynomial,
floor(scalar_abs(b)).astype("int64"),
num_terms,
)
is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
is_polynomial = a_is_polynomial | b_is_polynomial
return not is_undefined and ( return (~is_undefined) & (
is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b)) is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
) )
def compute_grad_2f1(a, b, c, z, wrt): def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
""" """
Notes Notes
----- -----
...@@ -1653,75 +1658,100 @@ class Hyp2F1Der(ScalarOp): ...@@ -1653,75 +1658,100 @@ class Hyp2F1Der(ScalarOp):
elif wrt != 2: elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}") raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
min_steps = 10 # https://github.com/stan-dev/math/issues/2857 min_steps = np.array(
max_steps = int(1e6) 10, dtype="int32"
precision = 1e-14 ) # https://github.com/stan-dev/math/issues/2857
max_steps = switch(
skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
)
precision = np.array(1e-14, dtype=config.floatX)
grad = np.array(0, dtype=dtype)
res = 0 log_g = np.array(-np.inf, dtype=dtype)
log_g_sign = np.array(1, dtype="int8")
if z == 0: log_t = np.array(0.0, dtype=dtype)
return res log_t_sign = np.array(1, dtype="int8")
log_g_old = -np.inf log_z = log(scalar_abs(z))
log_t_old = 0.0 sign_z = _unsafe_sign(z)
log_t_new = 0.0
sign_z = np.sign(z)
log_z = np.log(np.abs(z))
log_g_old_sign = 1
log_t_old_sign = 1
log_t_new_sign = 1
sign_zk = sign_z sign_zk = sign_z
k = np.array(0, dtype="int32")
for k in range(max_steps): def inner_loop(
grad,
log_g,
log_g_sign,
log_t,
log_t_sign,
sign_zk,
k,
a,
b,
c,
log_z,
sign_z,
):
p = (a + k) * (b + k) / ((c + k) * (k + 1)) p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p == 0: if p.type.dtype != dtype:
return res p = p.astype(dtype)
log_t_new += np.log(np.abs(p)) + log_z
log_t_new_sign = np.sign(p) * log_t_new_sign
term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old) term = log_g_sign * log_t_sign * exp(log_g - log_t)
if wrt_a: if wrt_a:
term += np.reciprocal(a + k) term += reciprocal(a + k)
elif wrt_b: elif wrt_b:
term += np.reciprocal(b + k) term += reciprocal(b + k)
else: else:
term -= np.reciprocal(c + k) term -= reciprocal(c + k)
log_g_old = log_t_new + np.log(np.abs(term)) if term.type.dtype != dtype:
log_g_old_sign = np.sign(term) * log_t_new_sign term = term.astype(dtype)
g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
res += g_current
log_t_old = log_t_new log_t = log_t + log(scalar_abs(p)) + log_z
log_t_old_sign = log_t_new_sign log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
sign_zk *= sign_z log_g = log_t + log(scalar_abs(term))
log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
if k >= min_steps and np.abs(g_current) <= precision:
return res
warnings.warn( g_current = log_g_sign * exp(log_g) * sign_zk
f"hyp2f1_der did not converge after {k} iterations",
RuntimeWarning,
)
return np.nan
# TODO: We could implement the Euler transform to expand supported domain, as Stan does # If p==0, don't update grad and get out of while loop next
if not check_2f1_converges(a, b, c, z): grad = switch(
warnings.warn( eq(p, 0),
f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}", grad,
RuntimeWarning, grad + g_current,
) )
return np.nan
return compute_grad_2f1(a, b, c, z, wrt=wrt) sign_zk *= sign_z
k += 1
def __call__(self, a, b, c, z, wrt, **kwargs): return (
# This allows wrt to be a keyword argument (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
return super().__call__(a, b, c, z, wrt, **kwargs) (eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))),
)
def c_code(self, *args, **kwargs): init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
raise NotImplementedError() constant = [a, b, c, log_z, sign_z]
grad = _make_scalar_loop(
max_steps, init, constant, inner_loop, name="hyp2f1_grad"
)
return switch(
eq(z, 0),
0,
grad,
)
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der") # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
z_is_zero = eq(z, 0)
converges = check_2f1_converges(a, b, c, z)
return switch(
z_is_zero,
0,
switch(
converges,
compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
np.nan,
),
)
from contextlib import ExitStack as does_not_warn import warnings
import numpy as np import numpy as np
import pytest import pytest
...@@ -872,7 +872,30 @@ TestHyp2F1InplaceBroadcast = makeBroadcastTester( ...@@ -872,7 +872,30 @@ TestHyp2F1InplaceBroadcast = makeBroadcastTester(
) )
def test_hyp2f1_grad_stan_cases(): class TestHyp2F1Grad:
few_iters_case = (
2.0,
1.0,
2.0,
0.4,
0.4617734323582945,
0.851376039609984,
-0.4617734323582945,
2.777777777777778,
)
many_iters_case = (
3.70975,
1.0,
2.70975,
0.999696,
29369830.002773938200417693317785,
36347869.41885337,
-30843032.10697079073015067426929807,
26278034019.28811,
)
def test_hyp2f1_grad_stan_cases(self):
"""This test reuses the same test cases as in: """This test reuses the same test cases as in:
https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp
https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp
...@@ -880,12 +903,11 @@ def test_hyp2f1_grad_stan_cases(): ...@@ -880,12 +903,11 @@ def test_hyp2f1_grad_stan_cases():
Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests
""" """
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z") a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
betainc_out = at.hyp2f1(a1, a2, b1, z) hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
betainc_grad = at.grad(betainc_out, [a1, a2, b1, z]) hyp2f1_grad = at.grad(hyp2f1_out, [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], betainc_grad) f_grad = function([a1, a2, b1, z], hyp2f1_grad)
rtol = 1e-9 if config.floatX == "float64" else 1e-3
rtol = 1e-9 if config.floatX == "float64" else 2e-3
for ( for (
test_a1, test_a1,
test_a2, test_a2,
...@@ -987,26 +1009,8 @@ def test_hyp2f1_grad_stan_cases(): ...@@ -987,26 +1009,8 @@ def test_hyp2f1_grad_stan_cases():
0.0677809985598383, 0.0677809985598383,
0.865295247272367, 0.865295247272367,
), ),
( self.few_iters_case,
2.0, self.many_iters_case,
1.0,
2.0,
0.4,
0.4617734323582945,
0.851376039609984,
-0.4617734323582945,
2.777777777777778,
),
(
3.70975,
1.0,
2.70975,
0.999696,
29369830.002773938200417693317785,
36347869.41885337,
-30843032.10697079073015067426929807,
26278034019.28811,
),
# Cases where series does not converge # Cases where series does not converge
(1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf), (1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf),
(1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf), (1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf),
...@@ -1014,16 +1018,13 @@ def test_hyp2f1_grad_stan_cases(): ...@@ -1014,16 +1018,13 @@ def test_hyp2f1_grad_stan_cases():
# (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889), # (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889),
(1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889), (1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889),
): ):
expectation = ( with warnings.catch_warnings():
pytest.warns( warnings.simplefilter("error")
RuntimeWarning, match="Hyp2F1 does not meet convergence conditions" warnings.filterwarnings(
) "ignore",
if np.any( category=RuntimeWarning,
np.isnan([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]) message="divide by zero encountered in log",
)
else does_not_warn()
) )
with expectation:
result = np.array(f_grad(test_a1, test_a2, test_b1, test_z)) result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -1031,3 +1032,23 @@ def test_hyp2f1_grad_stan_cases(): ...@@ -1031,3 +1032,23 @@ def test_hyp2f1_grad_stan_cases():
np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]), np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]),
rtol=rtol, rtol=rtol,
) )
@pytest.mark.parametrize("case", (few_iters_case, many_iters_case))
@pytest.mark.parametrize("wrt", ("a", "all"))
def test_benchmark(self, case, wrt, benchmark):
a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
hyp2f1_grad = at.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z])
f_grad = function([a1, a2, b1, z], hyp2f1_grad)
(test_a1, test_a2, test_b1, test_z, *expected_dds) = case
result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z)
rtol = 1e-9 if config.floatX == "float64" else 2e-3
expected_result = expected_dds[0] if wrt == "a" else np.array(expected_dds)
np.testing.assert_allclose(
result,
expected_result,
rtol=rtol,
)
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论