Unverified 提交 d8868ccb authored 作者: Adriano M. Yoshino's avatar Adriano M. Yoshino 提交者: GitHub

Add inf special cases to gamma.c function (#634)

上级 453fb4d2
...@@ -218,6 +218,11 @@ DEVICE double GammaP (double n, double x) ...@@ -218,6 +218,11 @@ DEVICE double GammaP (double n, double x)
{ /* --- regularized Gamma function P */ { /* --- regularized Gamma function P */
if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */ if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
if (x <= 0) return 0; /* treat x = 0 as a special case */ if (x <= 0) return 0; /* treat x = 0 as a special case */
if (isinf(n)) {
if (isinf(x)) return NPY_NAN;
return 0;
}
if (isinf(x)) return 1;
if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n)); if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n));
return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n)); return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
} /* GammaP() */ } /* GammaP() */
...@@ -228,6 +233,11 @@ DEVICE double GammaQ (double n, double x) ...@@ -228,6 +233,11 @@ DEVICE double GammaQ (double n, double x)
{ /* --- regularized Gamma function Q */ { /* --- regularized Gamma function Q */
if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */ if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
if (x <= 0) return 1; /* treat x = 0 as a special case */ if (x <= 0) return 1; /* treat x = 0 as a special case */
if (isinf(n)) {
if (isinf(x)) return NPY_NAN;
return 1;
}
if (isinf(x)) return 0;
if (x < n+1) return 1 -_series(n, x) *exp(n *log(x) -x -logGamma(n)); if (x < n+1) return 1 -_series(n, x) *exp(n *log(x) -x -logGamma(n));
return _cfrac(n, x) *exp(n *log(x) -x -logGamma(n)); return _cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
} /* GammaQ() */ } /* GammaQ() */
......
...@@ -631,6 +631,13 @@ class Chi2SF(BinaryScalarOp): ...@@ -631,6 +631,13 @@ class Chi2SF(BinaryScalarOp):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def c_code_cache_version(self):
v = super().c_code_cache_version()
if v:
return (2, *v)
else:
return v
chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf") chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf")
...@@ -677,6 +684,13 @@ class GammaInc(BinaryScalarOp): ...@@ -677,6 +684,13 @@ class GammaInc(BinaryScalarOp):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def c_code_cache_version(self):
v = super().c_code_cache_version()
if v:
return (2, *v)
else:
return v
gammainc = GammaInc(upgrade_to_float, name="gammainc") gammainc = GammaInc(upgrade_to_float, name="gammainc")
...@@ -723,6 +737,13 @@ class GammaIncC(BinaryScalarOp): ...@@ -723,6 +737,13 @@ class GammaIncC(BinaryScalarOp):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def c_code_cache_version(self):
v = super().c_code_cache_version()
if v:
return (2, *v)
else:
return v
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
......
...@@ -41,6 +41,16 @@ def test_gammainc_nan_c(): ...@@ -41,6 +41,16 @@ def test_gammainc_nan_c():
assert np.isnan(test_func(-1, -1)) assert np.isnan(test_func(-1, -1))
def test_gammainc_inf_c():
x1 = pt.dscalar()
x2 = pt.dscalar()
y = gammainc(x1, x2)
test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isclose(test_func(np.inf, 1), sp.gammainc(np.inf, 1))
assert np.isclose(test_func(1, np.inf), sp.gammainc(1, np.inf))
assert np.isnan(test_func(np.inf, np.inf))
def test_gammaincc_python(): def test_gammaincc_python():
x1 = pt.dscalar() x1 = pt.dscalar()
x2 = pt.dscalar() x2 = pt.dscalar()
...@@ -59,6 +69,16 @@ def test_gammaincc_nan_c(): ...@@ -59,6 +69,16 @@ def test_gammaincc_nan_c():
assert np.isnan(test_func(-1, -1)) assert np.isnan(test_func(-1, -1))
def test_gammaincc_inf_c():
x1 = pt.dscalar()
x2 = pt.dscalar()
y = gammaincc(x1, x2)
test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isclose(test_func(np.inf, 1), sp.gammaincc(np.inf, 1))
assert np.isclose(test_func(1, np.inf), sp.gammaincc(1, np.inf))
assert np.isnan(test_func(np.inf, np.inf))
def test_gammal_nan_c(): def test_gammal_nan_c():
x1 = pt.dscalar() x1 = pt.dscalar()
x2 = pt.dscalar() x2 = pt.dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论