提交 e5de51da authored 作者: Ethan Buchman's avatar Ethan Buchman

added scipy.stats.chi2.sf op

上级 0845ddc3
#definition theano.scalar op that have their python implementation taked from scipy #definition theano.scalar op that have their python implementation taked from scipy
#as scipy is not always available, we put threat them separatly #as scipy is not always available, we treat them separatly
import numpy import numpy
from theano.scalar.basic import (UnaryScalarOp, from theano.scalar.basic import (UnaryScalarOp,
...@@ -12,7 +12,10 @@ from theano.scalar.basic import (upgrade_to_float_no_complex, ...@@ -12,7 +12,10 @@ from theano.scalar.basic import (upgrade_to_float_no_complex,
imported_scipy_special = False imported_scipy_special = False
try: try:
import scipy.special import scipy.special
import scipy.stats
imported_scipy_special = True imported_scipy_special = True
# note: the Erf functions check for this but others do not. Should they?
# Importing scipy.special may raise ValueError. # Importing scipy.special may raise ValueError.
# See http://projects.scipy.org/scipy/ticket/1739 # See http://projects.scipy.org/scipy/ticket/1739
except (ImportError, ValueError): except (ImportError, ValueError):
...@@ -279,3 +282,277 @@ DEVICE double _psi(double x){ ...@@ -279,3 +282,277 @@ DEVICE double _psi(double x){
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
psi = Psi(upgrade_to_float, name='psi') psi = Psi(upgrade_to_float, name='psi')
from theano.scalar import ScalarOp #necessary since chi2sf takes two inputs
class Chi2SF(ScalarOp):
"""
Compute (1 - chi2_cdf(x))
ie. chi2 pvalue (chi2 'survival function')
"""
@staticmethod
def st_impl(x, k):
return scipy.stats.chi2.sf(x, k)
def impl(self, x, k):
return Chi2SF.st_impl(x, k)
def c_support_code(self):
return(
"""
/*----------------------------------------------------------------------
File : gamma.c
Contents: computation of the (incomplete/regularized) gamma function
Author : Christian Borgelt
History : 2002.07.04 file created
2003.05.19 incomplete Gamma function added
2008.03.14 more incomplete Gamma functions added
2008.03.15 table of factorials and logarithms added
2008.03.17 gamma distribution functions added
----------------------------------------------------------------------*/
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <float.h>
#include <math.h>
/*----------------------------------------------------------------------
Preprocessor Definitions
----------------------------------------------------------------------*/
#define LN_BASE 2.71828182845904523536028747135 /* e */
#define SQRT_PI 1.77245385090551602729816748334 /* \sqrt(\pi) */
#define LN_PI 1.14472988584940017414342735135 /* \ln(\pi) */
#define LN_SQRT_2PI 0.918938533204672741780329736406
/* \ln(\sqrt(2\pi)) */
#define EPSILON 2.2204460492503131e-16
#define EPS_QTL 1.4901161193847656e-08
#define MAXFACT 170
#define MAXITER 1024
#define TINY (EPSILON *EPSILON *EPSILON)
/*----------------------------------------------------------------------
Table of Factorials/Gamma Values
----------------------------------------------------------------------*/
static double _facts[MAXFACT+1] = { 0 };
static double _logfs[MAXFACT+1];
static double _halfs[MAXFACT+1];
static double _loghs[MAXFACT+1];
/*----------------------------------------------------------------------
Functions
----------------------------------------------------------------------*/
static void _init (void)
{ /* --- init. factorial tables */
int i; /* loop variable */
double x = 1; /* factorial */
_facts[0] = _facts[1] = 1; /* store factorials for 0 and 1 */
_logfs[0] = _logfs[1] = 0; /* and their logarithms */
for (i = 1; ++i <= MAXFACT; ) {
_facts[i] = x *= i; /* initialize the factorial table */
_logfs[i] = log(x); /* and the table of their logarithms */
}
_halfs[0] = x = SQRT_PI; /* store Gamma(0.5) */
_loghs[0] = 0.5*LN_PI; /* and its logarithm */
for (i = 0; ++i < MAXFACT; ) {
_halfs[i] = x *= i-0.5; /* initialize the table for */
_loghs[i] = log(x); /* the Gamma function of half numbers */
} /* and the table of their logarithms */
} /* _init() */
/*--------------------------------------------------------------------*/
#if 0
double logGamma (double n)
{ /* --- compute ln(Gamma(n)) */
double s; /* = ln((n-1)!), n \in IN */
assert(n > 0); /* check the function argument */
if (_facts[0] <= 0) _init(); /* initialize the tables */
if (n < MAXFACT +1 +4 *EPSILON) {
if (fabs( n -floor( n)) < 4 *EPSILON)
return _logfs[(int)floor(n)-1];
if (fabs(2*n -floor(2*n)) < 4 *EPSILON)
return _loghs[(int)floor(n)];
} /* try to get the value from a table */
s = 1.000000000190015 /* otherwise compute it */
+ 76.18009172947146 /(n+1)
- 86.50532032941677 /(n+2)
+ 24.01409824083091 /(n+3)
- 1.231739572450155 /(n+4)
+ 0.1208650972866179e-2 /(n+5)
- 0.5395239384953e-5 /(n+6);
return (n+0.5) *log((n+5.5)/LN_BASE) +(LN_SQRT_2PI +log(s/n) -5.0);
} /* logGamma() */
#else /*--------------------------------------------------------------*/
double logGamma (double n)
{ /* --- compute ln(Gamma(n)) */
double s; /* = ln((n-1)!), n \in IN */
assert(n > 0); /* check the function argument */
if (_facts[0] <= 0) _init(); /* initialize the tables */
if (n < MAXFACT +1 +4 *EPSILON) {
if (fabs( n -floor( n)) < 4 *EPSILON)
return _logfs[(int)floor(n)-1];
if (fabs(2*n -floor(2*n)) < 4 *EPSILON)
return _loghs[(int)floor(n)];
} /* try to get the value from a table */
s = 0.99999999999980993227684700473478 /* otherwise compute it */
+ 676.520368121885098567009190444019 /(n+1)
- 1259.13921672240287047156078755283 /(n+2)
+ 771.3234287776530788486528258894 /(n+3)
- 176.61502916214059906584551354 /(n+4)
+ 12.507343278686904814458936853 /(n+5)
- 0.13857109526572011689554707 /(n+6)
+ 9.984369578019570859563e-6 /(n+7)
+ 1.50563273514931155834e-7 /(n+8);
return (n+0.5) *log((n+7.5)/LN_BASE) +(LN_SQRT_2PI +log(s/n) -7.0);
} /* logGamma() */
#endif
/*----------------------------------------------------------------------
Use Lanczos' approximation
\Gamma(n+1) = (n+\gamma+0.5)^(n+0.5)
* e^{-(n+\gamma+0.5)}
* \sqrt{2\pi}
* (c_0 +c_1/(n+1) +c_2/(n+2) +...+c_n/(n+k) +\epsilon)
and exploit the recursion \Gamma(n+1) = n *\Gamma(n) once,
i.e., compute \Gamma(n) as \Gamma(n+1) /n.
For the choices \gamma = 5, k = 6, and c_0 to c_6 as defined
in the first version, it is |\epsilon| < 2e-10 for all n > 0.
Source: W.H. Press, S.A. Teukolsky, W.T. Vetterling, and B.P. Flannery
Numerical Recipes in C - The Art of Scientific Computing
Cambridge University Press, Cambridge, United Kingdom 1992
pp. 213-214
For the choices gamma = 7, k = 8, and c_0 to c_8 as defined
in the second version, the value is slightly more accurate.
----------------------------------------------------------------------*/
double Gamma (double n)
{ /* --- compute Gamma(n) = (n-1)! */
assert(n > 0); /* check the function argument */
if (_facts[0] <= 0) _init(); /* initialize the tables */
if (n < MAXFACT +1 +4 *EPSILON) {
if (fabs( n -floor( n)) < 4 *EPSILON)
return _facts[(int)floor(n)-1];
if (fabs(2*n -floor(2*n)) < 4 *EPSILON)
return _halfs[(int)floor(n)];
} /* try to get the value from a table */
return exp(logGamma(n)); /* compute through natural logarithm */
} /* Gamma() */
/*--------------------------------------------------------------------*/
static double _series (double n, double x)
{ /* --- series approximation */
int i; /* loop variable */
double t, sum; /* buffers */
sum = t = 1/n; /* compute initial values */
for (i = MAXITER; --i >= 0; ) {
sum += t *= x/++n; /* add one term of the series */
if (fabs(t) < fabs(sum) *EPSILON) break;
} /* if term is small enough, abort */
return sum; /* return the computed factor */
} /* _series() */
/*----------------------------------------------------------------------
series approximation:
P(a,x) = \gamma(a,x)/\Gamma(a)
\gamma(a,x) = e^-x x^a \sum_{n=0}^\infty (\Gamma(a)/\Gamma(a+1+n)) x^n
Source: W.H. Press, S.A. Teukolsky, W.T. Vetterling, and B.P. Flannery
Numerical Recipes in C - The Art of Scientific Computing
Cambridge University Press, Cambridge, United Kingdom 1992
formula: pp. 216-219
The factor exp(n *log(x) -x) is added in the functions below.
----------------------------------------------------------------------*/
static double _cfrac (double n, double x)
{ /* --- continued fraction approx. */
int i; /* loop variable */
double a, b, c, d, e, f; /* buffers */
b = x+1-n; c = 1/TINY; f = d = 1/b;
for (i = 1; i < MAXITER; i++) {
a = i*(n-i); /* use Lentz's algorithm to compute */
d = a *d +(b += 2); /* consecutive approximations */
if (fabs(d) < TINY) d = TINY;
c = b +a/c;
if (fabs(c) < TINY) c = TINY;
d = 1/d; f *= e = d *c;
if (fabs(e-1) < EPSILON) break;
} /* if factor is small enough, abort */
return f; /* return the computed factor */
} /* _cfrac() */
/*----------------------------------------------------------------------
continued fraction approximation:
P(a,x) = 1 -\Gamma(a,x)/\Gamma(a)
\Gamma(a,x) = e^-x x^a (1/(x+1-a- 1(1-a)/(x+3-a- 2*(2-a)/(x+5-a- ...))))
Source: W.H. Press, S.A. Teukolsky, W.T. Vetterling, and B.P. Flannery
Numerical Recipes in C - The Art of Scientific Computing
Cambridge University Press, Cambridge, United Kingdom 1992
formula: pp. 216-219
Lentz's algorithm: p. 171
The factor exp(n *log(x) -x) is added in the functions below.
----------------------------------------------------------------------*/
double lowerGamma (double n, double x)
{ /* --- lower incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */
return _series(n, x) *exp(n *log(x) -x);
} /* lowerGamma() */
/*--------------------------------------------------------------------*/
double upperGamma (double n, double x)
{ /* --- upper incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */
return _cfrac(n, x) *exp(n *log(x) -x);
} /* upperGamma() */
/*--------------------------------------------------------------------*/
double GammaP (double n, double x)
{ /* --- regularized Gamma function P */
assert((n > 0) && (x >= 0)); /* check the function arguments */
if (x <= 0) return 0; /* treat x = 0 as a special case */
if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n));
return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
} /* GammaP() */
//ebuchman: this function is equivalent to scipy.stats.chi2.sf
//it's the pvalue (survival function) of a chi2 distribution
double Chi2SF (double k, double x)
{
return 1 - GammaP(k/2., x/2.);
}
""")
def c_code(self, node, name, inp, out, sub):
x, k = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
Chi2SF(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
chi2sf = Chi2SF(upgrade_to_float, name='chi2sf')
...@@ -3260,6 +3260,11 @@ def gammaln(a): ...@@ -3260,6 +3260,11 @@ def gammaln(a):
def psi(a): def psi(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
@_scal_elemwise
def chi2sf(x, k):
"""chi squared survival function"""
@_scal_elemwise_with_nfunc('real', 1, -1) @_scal_elemwise_with_nfunc('real', 1, -1)
def real(z): def real(z):
......
...@@ -229,6 +229,10 @@ def gammaln_inplace(a): ...@@ -229,6 +229,10 @@ def gammaln_inplace(a):
def psi_inplace(a): def psi_inplace(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
@_scal_inplace
def chi2sf_inplace(x, k):
"""chi squared survival function"""
@_scal_inplace @_scal_inplace
def second_inplace(a): def second_inplace(a):
"""Fill `a` with `b`""" """Fill `a` with `b`"""
......
...@@ -52,6 +52,7 @@ imported_scipy_special = False ...@@ -52,6 +52,7 @@ imported_scipy_special = False
mode_no_scipy = get_default_mode() mode_no_scipy = get_default_mode()
try: try:
import scipy.special import scipy.special
import scipy.stats
imported_scipy_special = True imported_scipy_special = True
except ImportError: except ImportError:
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
...@@ -1419,6 +1420,7 @@ if imported_scipy_special: ...@@ -1419,6 +1420,7 @@ if imported_scipy_special:
expected_gamma = scipy.special.gamma expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
expected_chi2sf = scipy.stats.chi2.sf
skip_scipy = False skip_scipy = False
else: else:
expected_erf = [] expected_erf = []
...@@ -1428,6 +1430,7 @@ else: ...@@ -1428,6 +1430,7 @@ else:
expected_gamma = [] expected_gamma = []
expected_gammaln = [] expected_gammaln = []
expected_psi = [] expected_psi = []
expected_chi2sf = []
skip_scipy = "scipy is not present" skip_scipy = "scipy is not present"
ErfTester = makeBroadcastTester( ErfTester = makeBroadcastTester(
...@@ -1547,6 +1550,32 @@ PsiInplaceTester = makeBroadcastTester( ...@@ -1547,6 +1550,32 @@ PsiInplaceTester = makeBroadcastTester(
inplace=True, inplace=True,
skip=skip_scipy) skip=skip_scipy)
'''
#chi2sf takes two inputs, a value (x) and a degrees of freedom (k).
# not sure how to deal with that here...
_good_broadcast_unary_psi = dict(
normal=(rand_ranged(1, 10, (2, 3)),),
empty=(numpy.asarray([]),),)
Chi2SFTester = makeBroadcastTester(
op=tensor.chi2sf,
expected=expected_chi2sf,
good=_good_broadcast_unary_chi2sf,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
Chi2SFInplaceTester = makeBroadcastTester(
op=inplace.chi2sf_inplace,
expected=expected_chi2sf,
good=_good_broadcast_unary_chi2sf,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
'''
ZerosLikeTester = makeBroadcastTester( ZerosLikeTester = makeBroadcastTester(
op=tensor.zeros_like, op=tensor.zeros_like,
expected=numpy.zeros_like, expected=numpy.zeros_like,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论