提交 b18e0432 authored 作者: jsalvatier's avatar jsalvatier

added new scipy functions

上级 c0d9f466
...@@ -61,3 +61,101 @@ class Erfc(UnaryScalarOp): ...@@ -61,3 +61,101 @@ class Erfc(UnaryScalarOp):
# scipy.special.erfc don't support complex. Why? # scipy.special.erfc don't support complex. Why?
erfc = Erfc(upgrade_to_float_no_complex, name = 'erfc') erfc = Erfc(upgrade_to_float_no_complex, name = 'erfc')
class GammaLn(UnaryScalarOp):
"""
Log gamma function.
"""
@staticmethod
def st_impl(x):
return special.gammaln(x)
def impl(self, x):
return GammaLn.st_impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
return [gz * scalar_psi(x)]
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in [scalar.float32, scalar.float64]:
return """%(z)s =
lgamma(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gammaln = GammaLn(upgrade_to_float, name='gammaln')
class Psi(UnaryScalarOp):
"""
Derivative of log gamma function.
"""
@staticmethod
def st_impl(x):
return special.psi(x)
def impl(self, x):
return Psi.st_impl(x)
#def grad() no gradient now
def c_support_code(self):
return (
"""
#ifndef _PSIFUNCDEFINED
#define _PSIFUNCDEFINED
double _psi(double x){
/*taken from
Bernardo, J. M. (1976). Algorithm AS 103: Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf */
double y, R, psi_ = 0;
double S = 1.0e-5;
double C = 8.5;
double S3 = 8.333333333e-2;
double S4 = 8.333333333e-3;
double S5 = 3.968253968e-3;
double D1 = -0.5772156649 ;
y = x;
if (y <= 0.0)
return psi_;
if (y <= S )
return D1 - 1.0/y;
while (y < C){
psi_ = psi_ - 1.0 / y;
y = y + 1;}
R = 1.0 / y;
psi_ = psi_ + log(y) - .5 * R ;
R= R*R;
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
return psi_;}
#endif
""" )
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in [scalar.float32, scalar.float64]:
return """%(z)s =
_psi(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
psi = Psi(upgrade_to_float, name='psi')
ar
\ No newline at end of file
...@@ -2639,6 +2639,14 @@ def erf(a): ...@@ -2639,6 +2639,14 @@ def erf(a):
def erfc(a): def erfc(a):
"""complementary error function""" """complementary error function"""
@_scal_elemwise
def gammaln(a):
"""log gamma function"""
@_scal_elemwise
def psi(a):
"""derivative of log gamma function"""
@_scal_elemwise_with_nfunc('real', 1, -1) @_scal_elemwise_with_nfunc('real', 1, -1)
def real(z): def real(z):
......
...@@ -202,6 +202,14 @@ def erf_inplace(a): ...@@ -202,6 +202,14 @@ def erf_inplace(a):
@_scal_inplace @_scal_inplace
def erfc_inplace(a): def erfc_inplace(a):
"""complementary error function""" """complementary error function"""
@_scal_inplace
def gammaln_inplace(a):
"""log gamma function"""
@_scal_inplace
def psi_inplace(a):
"""derivative of log gamma function"""
@_scal_inplace @_scal_inplace
def second_inplace(a): def second_inplace(a):
......
...@@ -1234,6 +1234,9 @@ del _good_broadcast_unary_normal_no_int['integers'] ...@@ -1234,6 +1234,9 @@ del _good_broadcast_unary_normal_no_int['integers']
if imported_scipy_special: if imported_scipy_special:
expected_erf = scipy.special.erf expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc expected_erfc = scipy.special.erfc
expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi
skip_scipy = False skip_scipy = False
else: else:
expected_erf = [] expected_erf = []
...@@ -1272,6 +1275,38 @@ ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace, ...@@ -1272,6 +1275,38 @@ ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace,
inplace = True, inplace = True,
skip = skip_scipy) skip = skip_scipy)
GammaLnTester = makeBroadcastTester(op = tensor.gammaln,
expected = expected_gammaln,
good = _good_broadcast_unary_normal_no_int_no_complex,
grad = _grad_broadcast_unary_normal,
eps = 2e-10,
mode = mode_no_scipy,
skip = skip_scipy)
GammaLnInplaceTester = makeBroadcastTester(op = inplace.gammaln_inplace,
expected = expected_erfc,
good = _good_broadcast_unary_normal_no_int_no_complex,
grad = _grad_broadcast_unary_normal,
eps = 2e-10,
mode = mode_no_scipy,
inplace = True,
skip = skip_scipy)
PsiTester = makeBroadcastTester(op = tensor.psi,
expected = expected_psi,
good = _good_broadcast_unary_normal_no_int_no_complex,
grad = _grad_broadcast_unary_normal,
eps = 2e-10,
mode = mode_no_scipy,
skip = skip_scipy)
PsiInplaceTester = makeBroadcastTester(op = inplace.psi_inplace,
expected = expected_psi,
good = _good_broadcast_unary_normal_no_int_no_complex,
grad = _grad_broadcast_unary_normal,
eps = 2e-10,
mode = mode_no_scipy,
inplace = True,
skip = skip_scipy)
ZerosLikeTester = makeBroadcastTester( ZerosLikeTester = makeBroadcastTester(
op=tensor.zeros_like, op=tensor.zeros_like,
expected=numpy.zeros_like, expected=numpy.zeros_like,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论