提交 a6e79f28 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Restore Scipy-like precision for betainc

上级 2effaf15
...@@ -16,10 +16,12 @@ Copyright 1984, 1995, 2000 by Stephen L. Moshier ...@@ -16,10 +16,12 @@ Copyright 1984, 1995, 2000 by Stephen L. Moshier
#include <numpy/npy_math.h> #include <numpy/npy_math.h>
#define MINLOG -170.0 // Constants borrowed from Scipy
#define MAXLOG +170.0 // https://github.com/scipy/scipy/blob/81c53d48a290b604ec5faa34c0a7d48537b487d6/scipy/special/special/cephes/const.h#L65-L78
#define MINLOG -7.451332191019412076235E2 // log 2**-1022
#define MAXLOG 7.09782712893383996732E2 // log(DBL_MAX)
#define MAXGAM 171.624376956302725 #define MAXGAM 171.624376956302725
#define EPSILON 2.2204460492503131e-16 #define EPSILON 1.11022302462515654042e-16 // 2**-53
DEVICE static double pseries(double, double, double); DEVICE static double pseries(double, double, double);
DEVICE static double incbcf(double, double, double); DEVICE static double incbcf(double, double, double);
......
...@@ -1497,7 +1497,7 @@ class BetaInc(ScalarOp): ...@@ -1497,7 +1497,7 @@ class BetaInc(ScalarOp):
raise NotImplementedError("type not supported", type) raise NotImplementedError("type not supported", type)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
betainc = BetaInc(upgrade_to_float_no_complex, name="betainc") betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")
......
...@@ -99,12 +99,17 @@ def test_gammau_nan_c(): ...@@ -99,12 +99,17 @@ def test_gammau_nan_c():
assert np.isnan(test_func(-1, -1)) assert np.isnan(test_func(-1, -1))
def test_betainc(): @pytest.mark.parametrize("linker", ["py", "c"])
def test_betainc(linker):
a, b, x = pt.scalars("a", "b", "x") a, b, x = pt.scalars("a", "b", "x")
res = betainc(a, b, x) res = betainc(a, b, x)
test_func = function([a, b, x], res, mode=Mode("py")) test_func = function([a, b, x], res, mode=Mode(linker=linker, optimizer="fast_run"))
assert np.isclose(test_func(15, 10, 0.7), sp.betainc(15, 10, 0.7)) assert np.isclose(test_func(15, 10, 0.7), sp.betainc(15, 10, 0.7))
# Regression test for https://github.com/pymc-devs/pytensor/issues/906
if res.dtype == "float64":
assert test_func(100, 1.0, 0.1) > 0
def test_betainc_derivative_nan(): def test_betainc_derivative_nan():
a, b, x = pt.scalars("a", "b", "x") a, b, x = pt.scalars("a", "b", "x")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论