提交 3041831e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use ScalarLoop for gammainc(c) gradients

上级 cd93444e
差异被折叠。
......@@ -3,6 +3,8 @@ from contextlib import ExitStack as does_not_warn
import numpy as np
import pytest
from pytensor.gradient import verify_grad
scipy = pytest.importorskip("scipy")
......@@ -11,11 +13,11 @@ from functools import partial
import scipy.special
import scipy.stats
from pytensor import function
from pytensor import function, grad
from pytensor import tensor as at
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor import inplace
from pytensor.tensor import gammaincc, inplace, vector
from tests import unittest_tools as utt
from tests.tensor.utils import (
_good_broadcast_unary_chi2sf,
......@@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values():
gammaincc_ddk = at.grad(gammainc_out, k)
f_grad = function([k, x], gammaincc_ddk)
rtol = 1e-5 if config.floatX == "float64" else 1e-2
atol = 1e-10 if config.floatX == "float64" else 1e-6
for test_k, test_x, expected_ddk in (
(0.0001, 0, 0), # Limit condition
(0.0001, 0.0001, -8.62594024578651),
......@@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values():
(19.0001, 29.7501, -0.007828749832965796),
):
np.testing.assert_allclose(
f_grad(test_k, test_x), expected_ddk, rtol=1e-5, atol=1e-14
f_grad(test_k, test_x), expected_ddk, rtol=rtol, atol=atol
)
def test_gammaincc_ddk_performance(benchmark):
rng = np.random.default_rng(1)
k = vector("k")
x = vector("x")
out = gammaincc(k, x)
grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
vals = [
# Values that hit the second branch of the gradient
np.full((1000,), 3.2),
np.full((1000,), 0.01),
]
verify_grad(gammaincc, vals, rng=rng)
benchmark(grad_fn, *vals)
TestGammaUBroadcast = makeBroadcastTester(
op=at.gammau,
expected=expected_gammau,
......@@ -796,7 +818,7 @@ class TestBetaIncGrad:
betainc_out = at.betainc(a, b, z)
betainc_grad = at.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad)
decimal = 7 if config.floatX == "float64" else 5
for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
(1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
......@@ -806,6 +828,7 @@ class TestBetaIncGrad:
np.testing.assert_almost_equal(
f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb],
decimal=decimal,
)
def test_beta_inc_stan_grad_combined(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论