提交 3770c2a7 authored 作者: Aleksandar Botev's avatar Aleksandar Botev

Added the tri-gamma function.

上级 d9263f62
...@@ -298,8 +298,18 @@ class Psi(UnaryScalarOp): ...@@ -298,8 +298,18 @@ class Psi(UnaryScalarOp):
else: else:
super(Psi, self).impl(x) super(Psi, self).impl(x)
def grad(self, inputs, outputs_gradients): def L_op(self, inputs, outputs, grads):
raise NotImplementedError() x, = inputs
gz, = grads
if x.type in complex_types:
raise NotImplementedError()
if outputs[0].type in discrete_types:
if x.type in discrete_types:
return [x.zeros_like(dtype=theano.config.floatX)]
else:
return [x.zeros_like()]
return [gz * tri_gamma(x)]
def c_support_code(self): def c_support_code(self):
return ( return (
...@@ -365,6 +375,94 @@ class Psi(UnaryScalarOp): ...@@ -365,6 +375,94 @@ class Psi(UnaryScalarOp):
psi = Psi(upgrade_to_float, name='psi') psi = Psi(upgrade_to_float, name='psi')
class TriGamma(UnaryScalarOp):
"""
Second derivative of log gamma function.
"""
@staticmethod
def st_impl(x):
return scipy.special.polygamma(1, x)
def impl(self, x):
if imported_scipy_special:
return TriGamma.st_impl(x)
else:
super(TriGamma, self).impl(x)
def grad(self, inputs, outputs_gradients):
raise NotImplementedError()
def c_support_code(self):
# The implementation has been copied from
# http://people.sc.fsu.edu/~jburkardt/cpp_src/asa121/asa121.html
return (
"""
// For GPU support
#ifdef WITHIN_KERNEL
#define DEVICE WITHIN_KERNEL
#else
#define DEVICE
#endif
#ifndef ga_double
#define ga_double double
#endif
#ifndef _TRIGAMMAFUNCDEFINED
#define _TRIGAMMAFUNCDEFINED
DEVICE double _tri_gamma(ga_double x) {
double a = 0.0001;
double b = 5.0;
double b2 = 0.1666666667;
double b4 = -0.03333333333;
double b6 = 0.02380952381;
double b8 = -0.03333333333;
double value;
double y;
double z;
if (x <= 0) {
return 0.0;
}
if ( x <= a ) {
value = 1.0 / x / x;
return value;
}
value = 0.0;
z = x;
while ( z < b ) {
value += 1.0 / z / z;
z += 1.0;
}
y = 1.0 / z / z;
value += 0.5 * y + (1.0 + y * (b2 + y * (b4 + y * (b6 + y * b8 )))) / z;
return value;
}
#endif
""")
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
_tri_gamma(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
tri_gamma = TriGamma(upgrade_to_float, name='tri_gamma')
class Chi2SF(BinaryScalarOp): class Chi2SF(BinaryScalarOp):
""" """
Compute (1 - chi2_cdf(x)) ie. chi2 pvalue (chi2 'survival function'). Compute (1 - chi2_cdf(x)) ie. chi2 pvalue (chi2 'survival function').
......
...@@ -2253,6 +2253,11 @@ def psi(a): ...@@ -2253,6 +2253,11 @@ def psi(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
@_scal_elemwise
def tri_gamma(a):
"""second derivative of the log gamma function"""
@_scal_elemwise @_scal_elemwise
def chi2sf(x, k): def chi2sf(x, k):
"""chi squared survival function""" """chi squared survival function"""
......
...@@ -275,6 +275,11 @@ def psi_inplace(a): ...@@ -275,6 +275,11 @@ def psi_inplace(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
@_scal_inplace
def tri_gamma_inplace(a):
"""second derivative of the log gamma function"""
@_scal_inplace @_scal_inplace
def chi2sf_inplace(x, k): def chi2sf_inplace(x, k):
"""chi squared survival function""" """chi squared survival function"""
......
...@@ -1706,6 +1706,7 @@ if imported_scipy_special: ...@@ -1706,6 +1706,7 @@ if imported_scipy_special:
expected_gamma = scipy.special.gamma expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
expected_tri_gamma = partial(scipy.special.polygamma, 1)
expected_chi2sf = scipy.stats.chi2.sf expected_chi2sf = scipy.stats.chi2.sf
expected_j0 = scipy.special.j0 expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1 expected_j1 = scipy.special.j1
...@@ -1870,6 +1871,23 @@ PsiInplaceTester = makeBroadcastTester( ...@@ -1870,6 +1871,23 @@ PsiInplaceTester = makeBroadcastTester(
inplace=True, inplace=True,
skip=skip_scipy) skip=skip_scipy)
_good_broadcast_unary_tri_gamma = _good_broadcast_unary_psi
TriGammaTester = makeBroadcastTester(
op=tensor.tri_gamma,
expected=expected_tri_gamma,
good=_good_broadcast_unary_psi,
eps=2e-8,
mode=mode_no_scipy,
skip=skip_scipy)
TriGammaInplaceTester = makeBroadcastTester(
op=inplace.tri_gamma_inplace,
expected=expected_tri_gamma,
good=_good_broadcast_unary_tri_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
# chi2sf takes two inputs, a value (x) and a degrees of freedom (k). # chi2sf takes two inputs, a value (x) and a degrees of freedom (k).
# not sure how to deal with that here... # not sure how to deal with that here...
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论