提交 13be5b67 authored 作者: Luis Mario Domenzain's avatar Luis Mario Domenzain 提交者: Thomas Wiecki

Add gamma family functions from SciPy with C code for GPU (#6648)

* Add gamma family functions from SciPy with C code for GPU The C code is MIT-licensed. * Remove commented code and bring in missing functions * Indicate scalar definitions for Elementwise extension * Upgrade to float64 as previously * Add docstrings and remove chi2sf duplicate
上级 c5c549bc
...@@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division ...@@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division
# from SciPy. As SciPy is not always available, we treat them separately. # from SciPy. As SciPy is not always available, we treat them separately.
import numpy as np import numpy as np
import os
import theano import theano
from theano.gradient import grad_not_implemented from theano.gradient import grad_not_implemented
...@@ -480,13 +481,8 @@ tri_gamma = TriGamma(upgrade_to_float, name='tri_gamma') ...@@ -480,13 +481,8 @@ tri_gamma = TriGamma(upgrade_to_float, name='tri_gamma')
class Chi2SF(BinaryScalarOp): class Chi2SF(BinaryScalarOp):
""" """
Compute (1 - chi2_cdf(x)) ie. chi2 pvalue (chi2 'survival function'). Compute (1 - chi2_cdf(x))
ie. chi2 pvalue (chi2 'survival function')
C code is provided in the Theano_lgpl repository.
This make it faster.
https://github.com/Theano/Theano_lgpl.git
""" """
nfunc_spec = ('scipy.stats.chi2.sf', 2, 1) nfunc_spec = ('scipy.stats.chi2.sf', 2, 1)
...@@ -499,9 +495,206 @@ class Chi2SF(BinaryScalarOp): ...@@ -499,9 +495,206 @@ class Chi2SF(BinaryScalarOp):
return Chi2SF.st_impl(x, k) return Chi2SF.st_impl(x, k)
else: else:
super(Chi2SF, self).impl(x, k) super(Chi2SF, self).impl(x, k)
def c_support_code(self):
with open(os.path.join(
os.path.dirname(__file__),
'c_code',
'gamma.c')) as f:
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub):
x, k = inp
z, = out
if node.inputs[0].type in float_types:
dtype = 'npy_' + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) 1 - GammaP(%(k)s/2., %(x)s/2.);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
chi2sf = Chi2SF(upgrade_to_float64, name='chi2sf') chi2sf = Chi2SF(upgrade_to_float64, name='chi2sf')
class GammaInc(BinaryScalarOp):
"""
Compute the regularized lower gamma function (P).
"""
nfunc_spec = ('scipy.special.gammainc', 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammainc(k, x)
def impl(self, k, x):
if imported_scipy_special:
return GammaInc.st_impl(k, x)
else:
super(GammaInc, self).impl(k, x)
def c_support_code(self):
with open(os.path.join(
os.path.dirname(__file__),
'c_code',
'gamma.c')) as f:
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub):
k, x = inp
z, = out
if node.inputs[0].type in float_types:
dtype = 'npy_' + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) GammaP(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gammainc = GammaInc(upgrade_to_float, name='gammainc')
class GammaIncC(BinaryScalarOp):
"""
Compute the regularized upper gamma function (Q).
"""
nfunc_spec = ('scipy.special.gammaincc', 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammaincc(x, k)
def impl(self, k, x):
if imported_scipy_special:
return GammaIncC.st_impl(k, x)
else:
super(GammaIncC, self).impl(k, x)
def c_support_code(self):
with open(os.path.join(
os.path.dirname(__file__),
'c_code',
'gamma.c')) as f:
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub):
k, x = inp
z, = out
if node.inputs[0].type in float_types:
dtype = 'npy_' + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) gammaQ(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gammaincc = GammaIncC(upgrade_to_float, name='gammaincc')
class GammaU(BinaryScalarOp):
"""
compute the upper incomplete gamma function.
"""
# Note there is no basic SciPy version so no nfunc_spec.
@staticmethod
def st_impl(k, x):
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
def impl(self, k, x):
if imported_scipy_special:
return GammaU.st_impl(k, x)
else:
super(GammaU, self).impl(k, x)
def c_support_code(self):
with open(os.path.join(
os.path.dirname(__file__),
'c_code',
'gamma.c')) as f:
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub):
k, x = inp
z, = out
if node.inputs[0].type in float_types:
dtype = 'npy_' + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) upperGamma(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gammau = GammaU(upgrade_to_float, name='gammau')
class GammaL(BinaryScalarOp):
"""
Compute the lower incomplete gamma function.
"""
# Note there is no basic SciPy version so no nfunc_spec.
@staticmethod
def st_impl(k, x):
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
def impl(self, k, x):
if imported_scipy_special:
return GammaL.st_impl(k, x)
else:
super(GammaL, self).impl(k, x)
def c_support_code(self):
with open(os.path.join(
os.path.dirname(__file__),
'c_code',
'gamma.c')) as f:
raw = f.read()
return raw
def c_code(self, node, name, inp, out, sub):
k, x = inp
z, = out
if node.inputs[0].type in float_types:
dtype = 'npy_' + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) lowerGamma(%(k)s, %(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gammal = GammaL(upgrade_to_float, name='gammal')
class Jv(BinaryScalarOp): class Jv(BinaryScalarOp):
""" """
Bessel function of the first kind of order v (real). Bessel function of the first kind of order v (real).
......
差异被折叠。
...@@ -1891,7 +1891,6 @@ def isnan(a): ...@@ -1891,7 +1891,6 @@ def isnan(a):
def isinf(a): def isinf(a):
"""isinf(a)""" """isinf(a)"""
# Rename isnan to isnan_ to allow to bypass it when not needed. # Rename isnan to isnan_ to allow to bypass it when not needed.
# glibc 2.23 don't allow isnan on int, so we remove it from the graph. # glibc 2.23 don't allow isnan on int, so we remove it from the graph.
isinf_ = isinf isinf_ = isinf
...@@ -2403,6 +2402,26 @@ def chi2sf(x, k): ...@@ -2403,6 +2402,26 @@ def chi2sf(x, k):
"""chi squared survival function""" """chi squared survival function"""
@_scal_elemwise
def gammainc(k, x):
"""Regularized lower gamma function"""
@_scal_elemwise
def gammaincc(k, x):
"""Regularized upper gamma function"""
@_scal_elemwise
def gammau(k, x):
"""Upper incomplete gamma function."""
@_scal_elemwise
def gammal(k, x):
"""Lower incomplete gamma function."""
@_scal_elemwise @_scal_elemwise
def j0(x): def j0(x):
"""Bessel function of the first kind of order 0.""" """Bessel function of the first kind of order 0."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论