提交 426686fe authored 作者: ebuchman's avatar ebuchman

Merge pull request #1 from nouiz/ebuchman-chi2sf/master

Finish chi2 tests and GPU code.
......@@ -289,6 +289,7 @@ DEVICE double _psi(double x){
return hash(type(self))
psi = Psi(upgrade_to_float, name='psi')
class Chi2SF(BinaryScalarOp):
"""
Compute (1 - chi2_cdf(x))
......@@ -298,11 +299,13 @@ class Chi2SF(BinaryScalarOp):
@staticmethod
def st_impl(x, k):
return scipy.stats.chi2.sf(x, k)
def impl(self, x, k):
if imported_scipy_special:
return Chi2SF.st_impl(x, k)
else:
super(Chi2SF, self).impl(x, k)
def c_support_code(self):
return(
"""
......@@ -350,15 +353,15 @@ class Chi2SF(BinaryScalarOp):
/*----------------------------------------------------------------------
Table of Factorials/Gamma Values
----------------------------------------------------------------------*/
static double _facts[MAXFACT+1] = { 0 };
static double _logfs[MAXFACT+1];
static double _halfs[MAXFACT+1];
static double _loghs[MAXFACT+1];
DEVICE static double _facts[MAXFACT+1] = { 0 };
DEVICE static double _logfs[MAXFACT+1];
DEVICE static double _halfs[MAXFACT+1];
DEVICE static double _loghs[MAXFACT+1];
/*----------------------------------------------------------------------
Functions
----------------------------------------------------------------------*/
static void _init (void)
DEVICE static void _init (void)
{ /* --- init. factorial tables */
int i; /* loop variable */
double x = 1; /* factorial */
......@@ -404,7 +407,7 @@ class Chi2SF(BinaryScalarOp):
#else /*--------------------------------------------------------------*/
double logGamma (double n)
DEVICE double logGamma (double n)
{ /* --- compute ln(Gamma(n)) */
double s; /* = ln((n-1)!), n \in IN */
......@@ -450,7 +453,7 @@ class Chi2SF(BinaryScalarOp):
in the second version, the value is slightly more accurate.
----------------------------------------------------------------------*/
double Gamma (double n)
DEVICE double Gamma (double n)
{ /* --- compute Gamma(n) = (n-1)! */
assert(n > 0); /* check the function argument */
if (_facts[0] <= 0) _init(); /* initialize the tables */
......@@ -465,7 +468,7 @@ class Chi2SF(BinaryScalarOp):
/*--------------------------------------------------------------------*/
static double _series (double n, double x)
DEVICE static double _series (double n, double x)
{ /* --- series approximation */
int i; /* loop variable */
double t, sum; /* buffers */
......@@ -491,7 +494,7 @@ class Chi2SF(BinaryScalarOp):
The factor exp(n *log(x) -x) is added in the functions below.
----------------------------------------------------------------------*/
static double _cfrac (double n, double x)
DEVICE static double _cfrac (double n, double x)
{ /* --- continued fraction approx. */
int i; /* loop variable */
double a, b, c, d, e, f; /* buffers */
......@@ -523,7 +526,7 @@ class Chi2SF(BinaryScalarOp):
The factor exp(n *log(x) -x) is added in the functions below.
----------------------------------------------------------------------*/
double lowerGamma (double n, double x)
DEVICE double lowerGamma (double n, double x)
{ /* --- lower incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */
return _series(n, x) *exp(n *log(x) -x);
......@@ -531,7 +534,7 @@ class Chi2SF(BinaryScalarOp):
/*--------------------------------------------------------------------*/
double upperGamma (double n, double x)
DEVICE double upperGamma (double n, double x)
{ /* --- upper incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */
return _cfrac(n, x) *exp(n *log(x) -x);
......@@ -539,8 +542,7 @@ class Chi2SF(BinaryScalarOp):
/*--------------------------------------------------------------------*/
double GammaP (double n, double x)
DEVICE double GammaP (double n, double x)
{ /* --- regularized Gamma function P */
assert((n > 0) && (x >= 0)); /* check the function arguments */
if (x <= 0) return 0; /* treat x = 0 as a special case */
......@@ -555,21 +557,21 @@ class Chi2SF(BinaryScalarOp):
{
return 1 - GammaP(k/2., x/2.);
}
#endif
""")
def c_code(self, node, name, inp, out, sub):
x, k = inp
z, = out
if node.inputs[0].type in float_types:
dtype = z.dtype
dtype = 'npy_' + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s)Chi2SF(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论