提交 426686fe authored 作者: ebuchman's avatar ebuchman

Merge pull request #1 from nouiz/ebuchman-chi2sf/master

Finish chi2 tests and GPU code.
...@@ -289,6 +289,7 @@ DEVICE double _psi(double x){ ...@@ -289,6 +289,7 @@ DEVICE double _psi(double x){
return hash(type(self)) return hash(type(self))
psi = Psi(upgrade_to_float, name='psi') psi = Psi(upgrade_to_float, name='psi')
class Chi2SF(BinaryScalarOp): class Chi2SF(BinaryScalarOp):
""" """
Compute (1 - chi2_cdf(x)) Compute (1 - chi2_cdf(x))
...@@ -298,11 +299,13 @@ class Chi2SF(BinaryScalarOp): ...@@ -298,11 +299,13 @@ class Chi2SF(BinaryScalarOp):
@staticmethod @staticmethod
def st_impl(x, k): def st_impl(x, k):
return scipy.stats.chi2.sf(x, k) return scipy.stats.chi2.sf(x, k)
def impl(self, x, k): def impl(self, x, k):
if imported_scipy_special: if imported_scipy_special:
return Chi2SF.st_impl(x, k) return Chi2SF.st_impl(x, k)
else: else:
super(Chi2SF, self).impl(x, k) super(Chi2SF, self).impl(x, k)
def c_support_code(self): def c_support_code(self):
return( return(
""" """
...@@ -350,15 +353,15 @@ class Chi2SF(BinaryScalarOp): ...@@ -350,15 +353,15 @@ class Chi2SF(BinaryScalarOp):
/*---------------------------------------------------------------------- /*----------------------------------------------------------------------
Table of Factorials/Gamma Values Table of Factorials/Gamma Values
----------------------------------------------------------------------*/ ----------------------------------------------------------------------*/
static double _facts[MAXFACT+1] = { 0 }; DEVICE static double _facts[MAXFACT+1] = { 0 };
static double _logfs[MAXFACT+1]; DEVICE static double _logfs[MAXFACT+1];
static double _halfs[MAXFACT+1]; DEVICE static double _halfs[MAXFACT+1];
static double _loghs[MAXFACT+1]; DEVICE static double _loghs[MAXFACT+1];
/*---------------------------------------------------------------------- /*----------------------------------------------------------------------
Functions Functions
----------------------------------------------------------------------*/ ----------------------------------------------------------------------*/
static void _init (void) DEVICE static void _init (void)
{ /* --- init. factorial tables */ { /* --- init. factorial tables */
int i; /* loop variable */ int i; /* loop variable */
double x = 1; /* factorial */ double x = 1; /* factorial */
...@@ -404,7 +407,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -404,7 +407,7 @@ class Chi2SF(BinaryScalarOp):
#else /*--------------------------------------------------------------*/ #else /*--------------------------------------------------------------*/
double logGamma (double n) DEVICE double logGamma (double n)
{ /* --- compute ln(Gamma(n)) */ { /* --- compute ln(Gamma(n)) */
double s; /* = ln((n-1)!), n \in IN */ double s; /* = ln((n-1)!), n \in IN */
...@@ -450,7 +453,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -450,7 +453,7 @@ class Chi2SF(BinaryScalarOp):
in the second version, the value is slightly more accurate. in the second version, the value is slightly more accurate.
----------------------------------------------------------------------*/ ----------------------------------------------------------------------*/
double Gamma (double n) DEVICE double Gamma (double n)
{ /* --- compute Gamma(n) = (n-1)! */ { /* --- compute Gamma(n) = (n-1)! */
assert(n > 0); /* check the function argument */ assert(n > 0); /* check the function argument */
if (_facts[0] <= 0) _init(); /* initialize the tables */ if (_facts[0] <= 0) _init(); /* initialize the tables */
...@@ -465,7 +468,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -465,7 +468,7 @@ class Chi2SF(BinaryScalarOp):
/*--------------------------------------------------------------------*/ /*--------------------------------------------------------------------*/
static double _series (double n, double x) DEVICE static double _series (double n, double x)
{ /* --- series approximation */ { /* --- series approximation */
int i; /* loop variable */ int i; /* loop variable */
double t, sum; /* buffers */ double t, sum; /* buffers */
...@@ -491,7 +494,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -491,7 +494,7 @@ class Chi2SF(BinaryScalarOp):
The factor exp(n *log(x) -x) is added in the functions below. The factor exp(n *log(x) -x) is added in the functions below.
----------------------------------------------------------------------*/ ----------------------------------------------------------------------*/
static double _cfrac (double n, double x) DEVICE static double _cfrac (double n, double x)
{ /* --- continued fraction approx. */ { /* --- continued fraction approx. */
int i; /* loop variable */ int i; /* loop variable */
double a, b, c, d, e, f; /* buffers */ double a, b, c, d, e, f; /* buffers */
...@@ -523,7 +526,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -523,7 +526,7 @@ class Chi2SF(BinaryScalarOp):
The factor exp(n *log(x) -x) is added in the functions below. The factor exp(n *log(x) -x) is added in the functions below.
----------------------------------------------------------------------*/ ----------------------------------------------------------------------*/
double lowerGamma (double n, double x) DEVICE double lowerGamma (double n, double x)
{ /* --- lower incomplete Gamma fn. */ { /* --- lower incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */ assert((n > 0) && (x > 0)); /* check the function arguments */
return _series(n, x) *exp(n *log(x) -x); return _series(n, x) *exp(n *log(x) -x);
...@@ -531,7 +534,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -531,7 +534,7 @@ class Chi2SF(BinaryScalarOp):
/*--------------------------------------------------------------------*/ /*--------------------------------------------------------------------*/
double upperGamma (double n, double x) DEVICE double upperGamma (double n, double x)
{ /* --- upper incomplete Gamma fn. */ { /* --- upper incomplete Gamma fn. */
assert((n > 0) && (x > 0)); /* check the function arguments */ assert((n > 0) && (x > 0)); /* check the function arguments */
return _cfrac(n, x) *exp(n *log(x) -x); return _cfrac(n, x) *exp(n *log(x) -x);
...@@ -539,8 +542,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -539,8 +542,7 @@ class Chi2SF(BinaryScalarOp):
/*--------------------------------------------------------------------*/ /*--------------------------------------------------------------------*/
DEVICE double GammaP (double n, double x)
double GammaP (double n, double x)
{ /* --- regularized Gamma function P */ { /* --- regularized Gamma function P */
assert((n > 0) && (x >= 0)); /* check the function arguments */ assert((n > 0) && (x >= 0)); /* check the function arguments */
if (x <= 0) return 0; /* treat x = 0 as a special case */ if (x <= 0) return 0; /* treat x = 0 as a special case */
...@@ -555,21 +557,21 @@ class Chi2SF(BinaryScalarOp): ...@@ -555,21 +557,21 @@ class Chi2SF(BinaryScalarOp):
{ {
return 1 - GammaP(k/2., x/2.); return 1 - GammaP(k/2., x/2.);
} }
#endif
""") """)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, k = inp x, k = inp
z, = out z, = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = z.dtype dtype = 'npy_' + node.outputs[0].dtype
return """%(z)s = return """%(z)s =
(%(dtype)s)Chi2SF(%(k)s, %(x)s);""" % locals() (%(dtype)s)Chi2SF(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论