提交 08d16c0a authored 作者: Ethan Buchman's avatar Ethan Buchman

made chi2sf binary, fixed reviewer comments

上级 e5de51da
......@@ -2,7 +2,7 @@
#as scipy is not always available, we treat them separatly
import numpy
from theano.scalar.basic import (UnaryScalarOp,
from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
exp, upgrade_to_float,
float_types)
from theano.scalar.basic import (upgrade_to_float_no_complex,
......@@ -14,8 +14,6 @@ try:
import scipy.special
import scipy.stats
imported_scipy_special = True
# note: the Erf functions check for this but others do not. Should they?
# Importing scipy.special may raise ValueError.
# See http://projects.scipy.org/scipy/ticket/1739
except (ImportError, ValueError):
......@@ -155,7 +153,10 @@ class Gamma(UnaryScalarOp):
return scipy.special.gamma(x)
def impl(self, x):
return Gamma.st_impl(x)
if imported_scipy_special:
return Gamma.st_impl(x)
else:
super(Gamma, self).impl(x)
def grad(self, (x, ), (gz, )):
return gz * gamma(x) * psi(x),
......@@ -182,8 +183,10 @@ class GammaLn(UnaryScalarOp):
return scipy.special.gammaln(x)
def impl(self, x):
return GammaLn.st_impl(x)
if imported_scipy_special:
return GammaLn.st_impl(x)
else:
super(GammaLn, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
......@@ -214,7 +217,10 @@ class Psi(UnaryScalarOp):
return scipy.special.psi(x)
def impl(self, x):
return Psi.st_impl(x)
if imported_scipy_special:
return Psi.st_impl(x)
else:
super(Psi, self).impl(x)
def grad(self, inputs, outputs_gradients):
raise NotImplementedError()
......@@ -283,8 +289,7 @@ DEVICE double _psi(double x){
return hash(type(self))
psi = Psi(upgrade_to_float, name='psi')
from theano.scalar import ScalarOp #necessary since chi2sf takes two inputs
class Chi2SF(ScalarOp):
class Chi2SF(BinaryScalarOp):
"""
Compute (1 - chi2_cdf(x))
ie. chi2 pvalue (chi2 'survival function')
......@@ -294,11 +299,23 @@ class Chi2SF(ScalarOp):
def st_impl(x, k):
return scipy.stats.chi2.sf(x, k)
def impl(self, x, k):
return Chi2SF.st_impl(x, k)
if imported_scipy_special:
return Chi2SF.st_impl(x, k)
else:
super(Chi2SF, self).impl(x, k)
def c_support_code(self):
return(
"""
//For GPU support
#ifdef __CUDACC__
#define DEVICE __device__
#else
#define DEVICE
#endif
#ifndef _CHI2FUNCDEFINED
#define _CHI2FUNCDEFINED
/*----------------------------------------------------------------------
File : gamma.c
Contents: computation of the (incomplete/regularized) gamma function
......@@ -534,7 +551,7 @@ class Chi2SF(ScalarOp):
//ebuchman: this function is equivalent to scipy.stats.chi2.sf
//it's the pvalue (survival function) of a chi2 distribution
double Chi2SF (double k, double x)
DEVICE double Chi2SF (double k, double x)
{
return 1 - GammaP(k/2., x/2.);
}
......@@ -546,8 +563,9 @@ class Chi2SF(ScalarOp):
x, k = inp
z, = out
if node.inputs[0].type in float_types:
dtype = z.dtype
return """%(z)s =
Chi2SF(%(k)s, %(x)s);""" % locals()
(%(dtype)s)Chi2SF(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
......
......@@ -1555,7 +1555,7 @@ PsiInplaceTester = makeBroadcastTester(
#chi2sf takes two inputs, a value (x) and a degrees of freedom (k).
# not sure how to deal with that here...
_good_broadcast_unary_psi = dict(
_good_broadcast_unary_chi2sf = dict(
normal=(rand_ranged(1, 10, (2, 3)),),
empty=(numpy.asarray([]),),)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论