Unverified 提交 79523be3 authored 作者: ricardoV94's avatar ricardoV94 提交者: GitHub

Replace C assertions with NAN in `gamma.c` and add missing inplace gamma functions

上级 0344f1a8
...@@ -26,9 +26,9 @@ ...@@ -26,9 +26,9 @@
#ifndef _ISOC99_SOURCE #ifndef _ISOC99_SOURCE
#define _ISOC99_SOURCE #define _ISOC99_SOURCE
#endif /* needed for function log1p() */ #endif /* needed for function log1p() */
#include <assert.h>
#include <float.h> #include <float.h>
#include <math.h> #include <math.h>
#include <numpy/npy_math.h>
/*---------------------------------------------------------------------- /*----------------------------------------------------------------------
Preprocessor Definitions Preprocessor Definitions
...@@ -84,7 +84,7 @@ DEVICE double logGamma (double n) ...@@ -84,7 +84,7 @@ DEVICE double logGamma (double n)
{ /* --- compute ln(Gamma(n)) */ { /* --- compute ln(Gamma(n)) */
double s; /* = ln((n-1)!), n \in IN */ double s; /* = ln((n-1)!), n \in IN */
assert(n > 0); /* check the function argument */ if (n <= 0) return NPY_NAN; /* check the function arguments */
if (_facts[0] <= 0) _init(); /* initialize the tables */ if (_facts[0] <= 0) _init(); /* initialize the tables */
if (n < MAXFACT +1 +4 *EPSILON) { if (n < MAXFACT +1 +4 *EPSILON) {
if (fabs( n -floor( n)) < 4 *EPSILON) if (fabs( n -floor( n)) < 4 *EPSILON)
...@@ -127,7 +127,7 @@ in the second version, the value is slightly more accurate. ...@@ -127,7 +127,7 @@ in the second version, the value is slightly more accurate.
DEVICE double Gamma (double n) DEVICE double Gamma (double n)
{ /* --- compute Gamma(n) = (n-1)! */ { /* --- compute Gamma(n) = (n-1)! */
assert(n > 0); /* check the function argument */ if (n <= 0) return NPY_NAN; /* check the function arguments */
if (_facts[0] <= 0) _init(); /* initialize the tables */ if (_facts[0] <= 0) _init(); /* initialize the tables */
if (n < MAXFACT +1 +4 *EPSILON) { if (n < MAXFACT +1 +4 *EPSILON) {
if (fabs( n -floor( n)) < 4 *EPSILON) if (fabs( n -floor( n)) < 4 *EPSILON)
...@@ -200,7 +200,7 @@ The factor exp(n *log(x) -x) is added in the functions below. ...@@ -200,7 +200,7 @@ The factor exp(n *log(x) -x) is added in the functions below.
DEVICE double lowerGamma (double n, double x) DEVICE double lowerGamma (double n, double x)
{ /* --- lower incomplete Gamma fn. */ { /* --- lower incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */ if ((n <= 0) || (x <= 0)) return NPY_NAN; /* check the function arguments */
return _series(n, x) *exp(n *log(x) -x); return _series(n, x) *exp(n *log(x) -x);
} /* lowerGamma() */ } /* lowerGamma() */
...@@ -208,7 +208,7 @@ DEVICE double lowerGamma (double n, double x) ...@@ -208,7 +208,7 @@ DEVICE double lowerGamma (double n, double x)
DEVICE double upperGamma (double n, double x) DEVICE double upperGamma (double n, double x)
{ /* --- upper incomplete Gamma fn. */ { /* --- upper incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */ if ((n <= 0) || (x <= 0)) return NPY_NAN; /* check the function arguments */
return _cfrac(n, x) *exp(n *log(x) -x); return _cfrac(n, x) *exp(n *log(x) -x);
} /* upperGamma() */ } /* upperGamma() */
...@@ -216,7 +216,7 @@ DEVICE double upperGamma (double n, double x) ...@@ -216,7 +216,7 @@ DEVICE double upperGamma (double n, double x)
DEVICE double GammaP (double n, double x) DEVICE double GammaP (double n, double x)
{ /* --- regularized Gamma function P */ { /* --- regularized Gamma function P */
assert((n > 0) && (x >= 0)); /* check the function arguments */ if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
if (x <= 0) return 0; /* treat x = 0 as a special case */ if (x <= 0) return 0; /* treat x = 0 as a special case */
if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n)); if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n));
return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n)); return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
...@@ -226,7 +226,7 @@ DEVICE double GammaP (double n, double x) ...@@ -226,7 +226,7 @@ DEVICE double GammaP (double n, double x)
DEVICE double GammaQ (double n, double x) DEVICE double GammaQ (double n, double x)
{ /* --- regularized Gamma function Q */ { /* --- regularized Gamma function Q */
assert((n > 0) && (x >= 0)); /* check the function arguments */ if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
if (x <= 0) return 1; /* treat x = 0 as a special case */ if (x <= 0) return 1; /* treat x = 0 as a special case */
if (x < n+1) return 1 -_series(n, x) *exp(n *log(x) -x -logGamma(n)); if (x < n+1) return 1 -_series(n, x) *exp(n *log(x) -x -logGamma(n));
return _cfrac(n, x) *exp(n *log(x) -x -logGamma(n)); return _cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
...@@ -240,7 +240,7 @@ function (cdf) of a chi^2 distribution with k degrees of freedom. ...@@ -240,7 +240,7 @@ function (cdf) of a chi^2 distribution with k degrees of freedom.
DEVICE double Gammapdf (double x, double k, double theta) DEVICE double Gammapdf (double x, double k, double theta)
{ /* --- probability density function */ { /* --- probability density function */
assert((k > 0) && (theta > 0)); if ((k <= 0) || (theta <= 0)) return NPY_NAN; /* check the function arguments */
if (x < 0) return 0; /* support is non-negative x */ if (x < 0) return 0; /* support is non-negative x */
if (x <= 0) return (k == 1) ? 1/theta : 0; if (x <= 0) return (k == 1) ? 1/theta : 0;
if (k == 1) return exp(-x/theta) /theta; if (k == 1) return exp(-x/theta) /theta;
...@@ -252,7 +252,7 @@ double unitqtlP (double prob) ...@@ -252,7 +252,7 @@ double unitqtlP (double prob)
{ /* --- quantile of normal distrib. */ { /* --- quantile of normal distrib. */
double p, x; /* with mean 0 and variance 1 */ double p, x; /* with mean 0 and variance 1 */
assert((prob >= 0) && (prob <= 1)); /* check the function argument */ if ((prob < 0) || (prob > 1)) return NPY_NAN; /* check the function arguments */
if (prob >= 1.0) return DBL_MAX; /* check for limiting values */ if (prob >= 1.0) return DBL_MAX; /* check for limiting values */
if (prob <= 0.0) return -DBL_MAX; /* and return extrema */ if (prob <= 0.0) return -DBL_MAX; /* and return extrema */
p = prob -0.5; p = prob -0.5;
...@@ -323,8 +323,8 @@ DEVICE double GammaqtlP (double prob, double k, double theta) ...@@ -323,8 +323,8 @@ DEVICE double GammaqtlP (double prob, double k, double theta)
int n = 0; /* loop variable */ int n = 0; /* loop variable */
double x, f, a, d, dx, dp; /* buffers */ double x, f, a, d, dx, dp; /* buffers */
assert((k > 0) && (theta > 0) /* check the function arguments */ /* check the function arguments */
&& (prob >= 0) && (prob <= 1)); if ((k <= 0) || (theta <= 0) || (prob < 0) || (prob > 1)) return NPY_NAN;
if (prob >= 1.0) return DBL_MAX; if (prob >= 1.0) return DBL_MAX;
if (prob <= 0.0) return 0; /* handle limiting values */ if (prob <= 0.0) return 0; /* handle limiting values */
if (prob < 0.05) x = exp(logGamma(k) +log(prob) /k); if (prob < 0.05) x = exp(logGamma(k) +log(prob) /k);
...@@ -355,8 +355,8 @@ DEVICE double GammaqtlQ (double prob, double k, double theta) ...@@ -355,8 +355,8 @@ DEVICE double GammaqtlQ (double prob, double k, double theta)
int n = 0; /* loop variable */ int n = 0; /* loop variable */
double x, f, a, d, dx, dp; /* buffers */ double x, f, a, d, dx, dp; /* buffers */
assert((k > 0) && (theta > 0) /* check the function arguments */ /* check the function arguments */
&& (prob >= 0) && (prob <= 1)); if ((k <= 0) || (theta <= 0) || (prob < 0) || (prob > 1)) return NPY_NAN;
if (prob <= 0.0) return DBL_MAX; if (prob <= 0.0) return DBL_MAX;
if (prob >= 1.0) return 0; /* handle limiting values */ if (prob >= 1.0) return 0; /* handle limiting values */
if (prob < 0.05) x = logGamma(k) -log(prob); if (prob < 0.05) x = logGamma(k) -log(prob);
......
...@@ -258,6 +258,26 @@ def chi2sf_inplace(x, k): ...@@ -258,6 +258,26 @@ def chi2sf_inplace(x, k):
"""chi squared survival function""" """chi squared survival function"""
@scalar_elemwise
def gammainc_inplace(k, x):
"""regularized lower gamma function (P)"""
@scalar_elemwise
def gammaincc_inplace(k, x):
"""regularized upper gamma function (Q)"""
@scalar_elemwise
def gammau_inplace(k, x):
"""upper incomplete gamma function"""
@scalar_elemwise
def gammal_inplace(k, x):
"""lower incomplete gamma function"""
@scalar_elemwise @scalar_elemwise
def j0_inplace(x): def j0_inplace(x):
"""Bessel function of the first kind of order 0.""" """Bessel function of the first kind of order 0."""
......
import numpy as np
import aesara.tensor as aet
from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker
from aesara.scalar.basic_scipy import gammainc, gammaincc, gammal, gammau
def test_gammainc_nan():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammainc(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function()
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))
def test_gammaincc_nan():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammaincc(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function()
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))
def test_gammal_nan():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammal(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function()
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))
def test_gammau_nan():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammau(x1, x2)
test_func = CLinker().accept(FunctionGraph([x1, x2], [y])).make_function()
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))
...@@ -39,6 +39,14 @@ except ImportError: ...@@ -39,6 +39,14 @@ except ImportError:
mode_no_scipy = "FAST_RUN" mode_no_scipy = "FAST_RUN"
def scipy_special_gammau(k, x):
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
def scipy_special_gammal(k, x):
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
# We can't test it if scipy is not installed! # We can't test it if scipy is not installed!
# Precomputing the result is brittle(it have been broken!) # Precomputing the result is brittle(it have been broken!)
# As if we do any modification to random number here, # As if we do any modification to random number here,
...@@ -53,14 +61,18 @@ if imported_scipy_special: ...@@ -53,14 +61,18 @@ if imported_scipy_special:
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
expected_tri_gamma = partial(scipy.special.polygamma, 1) expected_tri_gamma = partial(scipy.special.polygamma, 1)
expected_chi2sf = scipy.stats.chi2.sf expected_chi2sf = scipy.stats.chi2.sf
expected_gammainc = scipy.special.gammainc
expected_gammaincc = scipy.special.gammaincc
expected_gammau = scipy_special_gammau
expected_gammal = scipy_special_gammal
expected_j0 = scipy.special.j0 expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1 expected_j1 = scipy.special.j1
expected_jv = scipy.special.jv expected_jv = scipy.special.jv
expected_i0 = scipy.special.i0 expected_i0 = scipy.special.i0
expected_i1 = scipy.special.i1 expected_i1 = scipy.special.i1
expected_iv = scipy.special.iv expected_iv = scipy.special.iv
skip_scipy = False
expected_erfcx = scipy.special.erfcx expected_erfcx = scipy.special.erfcx
skip_scipy = False
else: else:
expected_erf = [] expected_erf = []
expected_erfc = [] expected_erfc = []
...@@ -72,6 +84,10 @@ else: ...@@ -72,6 +84,10 @@ else:
expected_psi = [] expected_psi = []
expected_tri_gamma = [] expected_tri_gamma = []
expected_chi2sf = [] expected_chi2sf = []
expected_gammainc = []
expected_gammaincc = []
expected_gammau = []
expected_gammal = []
expected_j0 = [] expected_j0 = []
expected_j1 = [] expected_j1 = []
expected_jv = [] expected_jv = []
...@@ -281,6 +297,100 @@ TestChi2SFInplaceBroadcast = makeBroadcastTester( ...@@ -281,6 +297,100 @@ TestChi2SFInplaceBroadcast = makeBroadcastTester(
name="Chi2SF", name="Chi2SF",
) )
_good_broadcast_binary_gamma = dict(
normal=(rand_ranged(1e-2, 10, (2, 3)), rand_ranged(1e-2, 10, (2, 3))),
empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)),
int=(randint_ranged(1, 10, (2, 3)), randint_ranged(1, 10, (2, 3))),
uint8=(
randint_ranged(1, 6, (2, 3)).astype("uint8"),
randint_ranged(1, 6, (2, 3)).astype("uint8"),
),
uint16=(
randint_ranged(1, 10, (2, 3)).astype("uint16"),
randint_ranged(1, 10, (2, 3)).astype("uint16"),
),
uint64=(
randint_ranged(1, 10, (2, 3)).astype("uint64"),
randint_ranged(1, 10, (2, 3)).astype("uint64"),
),
)
TestGammaIncBroadcast = makeBroadcastTester(
op=aet.gammainc,
expected=expected_gammainc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
skip=skip_scipy,
)
TestGammaIncInplaceBroadcast = makeBroadcastTester(
op=inplace.gammainc_inplace,
expected=expected_gammainc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy,
)
TestGammaInccBroadcast = makeBroadcastTester(
op=aet.gammaincc,
expected=expected_gammaincc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
skip=skip_scipy,
)
TestGammaInccInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaincc_inplace,
expected=expected_gammaincc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy,
)
TestGammaUBroadcast = makeBroadcastTester(
op=aet.gammau,
expected=expected_gammau,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
skip=skip_scipy,
)
TestGammaUInplaceBroadcast = makeBroadcastTester(
op=inplace.gammau_inplace,
expected=expected_gammau,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy,
)
TestGammaLBroadcast = makeBroadcastTester(
op=aet.gammal,
expected=expected_gammal,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
skip=skip_scipy,
)
TestGammaLInplaceBroadcast = makeBroadcastTester(
op=inplace.gammal_inplace,
expected=expected_gammal,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy,
)
_good_broadcast_unary_bessel = dict( _good_broadcast_unary_bessel = dict(
normal=(rand_ranged(-10, 10, (2, 3)),), normal=(rand_ranged(-10, 10, (2, 3)),),
empty=(np.asarray([], dtype=config.floatX),), empty=(np.asarray([], dtype=config.floatX),),
......
...@@ -265,11 +265,30 @@ TestMulBroadcast = makeBroadcastTester( ...@@ -265,11 +265,30 @@ TestMulBroadcast = makeBroadcastTester(
), ),
) )
# Values are fixed, because the gradient evaluation in TestModBroadcast often
# fails when the inputs are close to each other (due to gradient discontinuity).
# fmt: off
_grad_broadcast_div_mod_normal = dict( _grad_broadcast_div_mod_normal = dict(
same_shapes=(rand(2, 3), rand_nonzero((2, 3))), same_shapes=(
scalar=(rand(2, 3), rand_nonzero((1, 1))), np.array([[-0.51157823, 0.02560825, -0.7482302], [0.05923786, -0.21001006, -0.66742722]]),
row=(rand(2, 3), rand_nonzero((1, 3))), np.array([[-0.02250197, -0.32979461, 0.32081774], [0.36419213, -0.54073201, 0.8932643]])
column=(rand(2, 3), rand_nonzero((2, 1))), ),
scalar=(
np.array([[0.32390696, -0.77305276, -0.66302977], [0.8214372, -0.31612823, -0.06294127]]),
np.array([[-0.86904352]])
),
row=(
np.array([[0.89763688, -0.09403658, 0.05847774], [-0.00694876, -0.08999577, 0.19857154]]),
np.array([[-0.47662978, 0.72692131, -0.18250251]])
),
column=(
np.array([[0.04506636, 0.05725927, -0.94947897], [0.39868416, -0.12655465, -0.87068554]]),
np.array([[-0.39040176], [0.76164576]])
),
# same_shapes=(rand(2, 3), rand_nonzero((2, 3))),
# scalar=(rand(2, 3), rand_nonzero((1, 1))),
# row=(rand(2, 3), rand_nonzero((1, 3))),
# column=(rand(2, 3), rand_nonzero((2, 1))),
# complex1=(randcomplex(2, 3), randcomplex_nonzero((2, 3))), # complex1=(randcomplex(2, 3), randcomplex_nonzero((2, 3))),
# complex2=(randcomplex(2, 3), rand_nonzero((2, 3))), # complex2=(randcomplex(2, 3), rand_nonzero((2, 3))),
# complex3=(rand(2, 3), randcomplex_nonzero((2, 3))), # complex3=(rand(2, 3), randcomplex_nonzero((2, 3))),
...@@ -278,6 +297,7 @@ _grad_broadcast_div_mod_normal = dict( ...@@ -278,6 +297,7 @@ _grad_broadcast_div_mod_normal = dict(
# empty1=(np.asarray([]), np.asarray([1.])), # empty1=(np.asarray([]), np.asarray([1.])),
# empty2=(np.asarray([0]), np.asarray([])), # empty2=(np.asarray([0]), np.asarray([])),
) )
# fmt: on
TestTrueDivBroadcast = makeBroadcastTester( TestTrueDivBroadcast = makeBroadcastTester(
op=true_div, op=true_div,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论