Commit 3041831e, authored by Ricardo Vieira, committed by Ricardo Vieira

Use ScalarLoop for gammainc(c) gradients

Parent: cd93444e
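
For orientation, here is a minimal sketch (not part of the commit) of the gradient graph this change affects. It uses only calls that appear in the new test further down the diff:

    import numpy as np

    from pytensor import function, grad
    from pytensor.tensor import gammaincc, vector

    k = vector("k")
    x = vector("x")

    # Gradient of the regularized upper incomplete gamma function with
    # respect to the shape parameter k; this is the gradient that the
    # commit reimplements with a ScalarLoop.
    grad_fn = function([k, x], grad(gammaincc(k, x).sum(), wrt=[k]))

    print(grad_fn(np.full((3,), 3.2), np.full((3,), 0.01)))
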
@@ -3,6 +3,8 @@ from contextlib import ExitStack as does_not_warn
 import numpy as np
 import pytest
+from pytensor.gradient import verify_grad
+
 scipy = pytest.importorskip("scipy")
@@ -11,11 +13,11 @@ from functools import partial
 import scipy.special
 import scipy.stats

-from pytensor import function
+from pytensor import function, grad
 from pytensor import tensor as at
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.tensor import inplace
+from pytensor.tensor import gammaincc, inplace, vector
 from tests import unittest_tools as utt
 from tests.tensor.utils import (
     _good_broadcast_unary_chi2sf,
@@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values():
     gammaincc_ddk = at.grad(gammainc_out, k)
     f_grad = function([k, x], gammaincc_ddk)

+    rtol = 1e-5 if config.floatX == "float64" else 1e-2
+    atol = 1e-10 if config.floatX == "float64" else 1e-6
+
     for test_k, test_x, expected_ddk in (
         (0.0001, 0, 0),  # Limit condition
         (0.0001, 0.0001, -8.62594024578651),
@@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values():
         (19.0001, 29.7501, -0.007828749832965796),
     ):
         np.testing.assert_allclose(
-            f_grad(test_k, test_x), expected_ddk, rtol=1e-5, atol=1e-14
+            f_grad(test_k, test_x), expected_ddk, rtol=rtol, atol=atol
         )
+
+def test_gammaincc_ddk_performance(benchmark):
+    rng = np.random.default_rng(1)
+
+    k = vector("k")
+    x = vector("x")
+
+    out = gammaincc(k, x)
+    grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
+    vals = [
+        # Values that hit the second branch of the gradient
+        np.full((1000,), 3.2),
+        np.full((1000,), 0.01),
+    ]
+
+    verify_grad(gammaincc, vals, rng=rng)
+    benchmark(grad_fn, *vals)
+
 TestGammaUBroadcast = makeBroadcastTester(
     op=at.gammau,
     expected=expected_gammau,
@@ -796,7 +818,7 @@ class TestBetaIncGrad:
         betainc_out = at.betainc(a, b, z)
         betainc_grad = at.grad(betainc_out, [a, b])
         f_grad = function([a, b, z], betainc_grad)
+        decimal = 7 if config.floatX == "float64" else 5
         for test_a, test_b, test_z, expected_dda, expected_ddb in (
             (1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
             (1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
@@ -806,6 +828,7 @@ class TestBetaIncGrad:
             np.testing.assert_almost_equal(
                 f_grad(test_a, test_b, test_z),
                 [expected_dda, expected_ddb],
+                decimal=decimal,
             )
     def test_beta_inc_stan_grad_combined(self):
...
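
Usage note (not part of the diff): the new test_gammaincc_ddk_performance test relies on the benchmark fixture, which assumes the pytest-benchmark plugin is installed; it can be selected on its own with, e.g., pytest -k gammaincc_ddk_performance.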