提交 08d16c0a authored 作者: Ethan Buchman's avatar Ethan Buchman

made chi2sf binary, fixed reviewer comments

上级 e5de51da
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#as scipy is not always available, we treat them separatly #as scipy is not always available, we treat them separatly
import numpy import numpy
from theano.scalar.basic import (UnaryScalarOp, from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
exp, upgrade_to_float, exp, upgrade_to_float,
float_types) float_types)
from theano.scalar.basic import (upgrade_to_float_no_complex, from theano.scalar.basic import (upgrade_to_float_no_complex,
...@@ -14,8 +14,6 @@ try: ...@@ -14,8 +14,6 @@ try:
import scipy.special import scipy.special
import scipy.stats import scipy.stats
imported_scipy_special = True imported_scipy_special = True
# note: the Erf functions check for this but others do not. Should they?
# Importing scipy.special may raise ValueError. # Importing scipy.special may raise ValueError.
# See http://projects.scipy.org/scipy/ticket/1739 # See http://projects.scipy.org/scipy/ticket/1739
except (ImportError, ValueError): except (ImportError, ValueError):
...@@ -155,7 +153,10 @@ class Gamma(UnaryScalarOp): ...@@ -155,7 +153,10 @@ class Gamma(UnaryScalarOp):
return scipy.special.gamma(x) return scipy.special.gamma(x)
def impl(self, x): def impl(self, x):
if imported_scipy_special:
return Gamma.st_impl(x) return Gamma.st_impl(x)
else:
super(Gamma, self).impl(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * gamma(x) * psi(x), return gz * gamma(x) * psi(x),
...@@ -182,8 +183,10 @@ class GammaLn(UnaryScalarOp): ...@@ -182,8 +183,10 @@ class GammaLn(UnaryScalarOp):
return scipy.special.gammaln(x) return scipy.special.gammaln(x)
def impl(self, x): def impl(self, x):
if imported_scipy_special:
return GammaLn.st_impl(x) return GammaLn.st_impl(x)
else:
super(GammaLn, self).impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
gz, = grads gz, = grads
...@@ -214,7 +217,10 @@ class Psi(UnaryScalarOp): ...@@ -214,7 +217,10 @@ class Psi(UnaryScalarOp):
return scipy.special.psi(x) return scipy.special.psi(x)
def impl(self, x): def impl(self, x):
if imported_scipy_special:
return Psi.st_impl(x) return Psi.st_impl(x)
else:
super(Psi, self).impl(x)
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
raise NotImplementedError() raise NotImplementedError()
...@@ -283,8 +289,7 @@ DEVICE double _psi(double x){ ...@@ -283,8 +289,7 @@ DEVICE double _psi(double x){
return hash(type(self)) return hash(type(self))
psi = Psi(upgrade_to_float, name='psi') psi = Psi(upgrade_to_float, name='psi')
from theano.scalar import ScalarOp #necessary since chi2sf takes two inputs class Chi2SF(BinaryScalarOp):
class Chi2SF(ScalarOp):
""" """
Compute (1 - chi2_cdf(x)) Compute (1 - chi2_cdf(x))
ie. chi2 pvalue (chi2 'survival function') ie. chi2 pvalue (chi2 'survival function')
...@@ -294,11 +299,23 @@ class Chi2SF(ScalarOp): ...@@ -294,11 +299,23 @@ class Chi2SF(ScalarOp):
def st_impl(x, k): def st_impl(x, k):
return scipy.stats.chi2.sf(x, k) return scipy.stats.chi2.sf(x, k)
def impl(self, x, k): def impl(self, x, k):
if imported_scipy_special:
return Chi2SF.st_impl(x, k) return Chi2SF.st_impl(x, k)
else:
super(Chi2SF, self).impl(x, k)
def c_support_code(self): def c_support_code(self):
return( return(
""" """
//For GPU support
#ifdef __CUDACC__
#define DEVICE __device__
#else
#define DEVICE
#endif
#ifndef _CHI2FUNCDEFINED
#define _CHI2FUNCDEFINED
/*---------------------------------------------------------------------- /*----------------------------------------------------------------------
File : gamma.c File : gamma.c
Contents: computation of the (incomplete/regularized) gamma function Contents: computation of the (incomplete/regularized) gamma function
...@@ -534,7 +551,7 @@ class Chi2SF(ScalarOp): ...@@ -534,7 +551,7 @@ class Chi2SF(ScalarOp):
//ebuchman: this function is equivalent to scipy.stats.chi2.sf //ebuchman: this function is equivalent to scipy.stats.chi2.sf
//it's the pvalue (survival function) of a chi2 distribution //it's the pvalue (survival function) of a chi2 distribution
double Chi2SF (double k, double x) DEVICE double Chi2SF (double k, double x)
{ {
return 1 - GammaP(k/2., x/2.); return 1 - GammaP(k/2., x/2.);
} }
...@@ -546,8 +563,9 @@ class Chi2SF(ScalarOp): ...@@ -546,8 +563,9 @@ class Chi2SF(ScalarOp):
x, k = inp x, k = inp
z, = out z, = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = z.dtype
return """%(z)s = return """%(z)s =
Chi2SF(%(k)s, %(x)s);""" % locals() (%(dtype)s)Chi2SF(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other): def __eq__(self, other):
......
...@@ -1555,7 +1555,7 @@ PsiInplaceTester = makeBroadcastTester( ...@@ -1555,7 +1555,7 @@ PsiInplaceTester = makeBroadcastTester(
#chi2sf takes two inputs, a value (x) and a degrees of freedom (k). #chi2sf takes two inputs, a value (x) and a degrees of freedom (k).
# not sure how to deal with that here... # not sure how to deal with that here...
_good_broadcast_unary_psi = dict( _good_broadcast_unary_chi2sf = dict(
normal=(rand_ranged(1, 10, (2, 3)),), normal=(rand_ranged(1, 10, (2, 3)),),
empty=(numpy.asarray([]),),) empty=(numpy.asarray([]),),)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论