提交 84fdef4f authored 作者: Frederic Bastien's avatar Frederic Bastien

Use another mechanism that make sure all instance will use the scipy nfunc.

上级 aff96770
......@@ -26,6 +26,7 @@ except (ImportError, ValueError):
class Erf(UnaryScalarOp):
nfunc_spec = ('scipy.special.erf', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erf(x)
......@@ -58,6 +59,8 @@ erf = Erf(upgrade_to_float, name='erf')
class Erfc(UnaryScalarOp):
nfunc_spec = ('scipy.special.erfc', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfc(x)
......@@ -105,6 +108,7 @@ class Erfcx(UnaryScalarOp):
running on GPU an optimization will replace it with a gpu version.
"""
nfunc_spec = ('scipy.special.erfcx', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcx(x)
......@@ -140,6 +144,7 @@ class Erfinv(UnaryScalarOp):
(TODO) Find a C implementation of erfinv for CPU.
"""
nfunc_spec = ('scipy.special.erfinv', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfinv(x)
......@@ -173,6 +178,8 @@ erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')
class Erfcinv(UnaryScalarOp):
nfunc_spec = ('scipy.special.erfcinv', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcinv(x)
......@@ -206,6 +213,8 @@ erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')
class Gamma(UnaryScalarOp):
nfunc_spec = ('scipy.special.gamma', 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.gamma(x)
......@@ -243,6 +252,8 @@ class GammaLn(UnaryScalarOp):
Log gamma function.
"""
nfunc_spec = ('scipy.special.gammaln', 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.gammaln(x)
......@@ -287,6 +298,8 @@ class Psi(UnaryScalarOp):
Derivative of log gamma function.
"""
nfunc_spec = ('scipy.special.psi', 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.psi(x)
......@@ -472,6 +485,7 @@ class Chi2SF(BinaryScalarOp):
https://github.com/Theano/Theano_lgpl.git
"""
nfunc_spec = ('scipy.stats.chi2.sf', 2, 1)
@staticmethod
def st_impl(x, k):
......@@ -489,6 +503,7 @@ class Jv(BinaryScalarOp):
"""
Bessel function of the first kind of order v (real).
"""
nfunc_spec = ('scipy.special.jv', 2, 1)
@staticmethod
def st_impl(v, x):
......@@ -513,6 +528,7 @@ class J1(UnaryScalarOp):
"""
Bessel function of the first kind of order 1.
"""
nfunc_spec = ('scipy.special.j1', 1, 1)
@staticmethod
def st_impl(x):
......@@ -544,6 +560,7 @@ class J0(UnaryScalarOp):
"""
Bessel function of the first kind of order 0.
"""
nfunc_spec = ('scipy.special.j0', 1, 1)
@staticmethod
def st_impl(x):
......@@ -575,6 +592,7 @@ class Iv(BinaryScalarOp):
"""
Modified Bessel function of the first kind of order v (real).
"""
nfunc_spec = ('scipy.special.iv', 2, 1)
@staticmethod
def st_impl(v, x):
......@@ -599,6 +617,7 @@ class I1(UnaryScalarOp):
"""
Modified Bessel function of the first kind of order 1.
"""
nfunc_spec = ('scipy.special.i1', 1, 1)
@staticmethod
def st_impl(x):
......@@ -622,6 +641,7 @@ class I0(UnaryScalarOp):
"""
Modified Bessel function of the first kind of order 0.
"""
nfunc_spec = ('scipy.special.i0', 1, 1)
@staticmethod
def st_impl(x):
......
......@@ -2342,42 +2342,42 @@ def arctanh(a):
"""hyperbolic arc tangent of a"""
@_scal_elemwise_with_nfunc('scipy.special.erf', 1, 1)
@_scal_elemwise
def erf(a):
"""error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfc', 1, 1)
@_scal_elemwise
def erfc(a):
"""complementary error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfcx', 1, 1)
@_scal_elemwise
def erfcx(a):
"""scaled complementary error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfinv', 1, 1)
@_scal_elemwise
def erfinv(a):
"""inverse error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfcinv', 1, 1)
@_scal_elemwise
def erfcinv(a):
"""inverse complementary error function"""
@_scal_elemwise_with_nfunc('scipy.special.gamma', 1, 1)
@_scal_elemwise
def gamma(a):
"""gamma function"""
@_scal_elemwise_with_nfunc('scipy.special.gammaln', 1, 1)
@_scal_elemwise
def gammaln(a):
"""log gamma function"""
@_scal_elemwise_with_nfunc('scipy.special.psi', 1, 1)
@_scal_elemwise
def psi(a):
"""derivative of log gamma function"""
......@@ -2387,37 +2387,37 @@ def tri_gamma(a):
"""second derivative of the log gamma function"""
@_scal_elemwise_with_nfunc('scipy.stats.chi2.sf', 2, 1)
@_scal_elemwise
def chi2sf(x, k):
"""chi squared survival function"""
@_scal_elemwise_with_nfunc('scipy.special.j0', 1, 1)
@_scal_elemwise
def j0(x):
"""Bessel function of the first kind of order 0."""
@_scal_elemwise_with_nfunc('scipy.special.j1', 1, 1)
@_scal_elemwise
def j1(x):
"""Bessel function of the first kind of order 1."""
@_scal_elemwise_with_nfunc('scipy.special.jv', 2, 1)
@_scal_elemwise
def jv(v, x):
"""Bessel function of the first kind of order v (real)."""
@_scal_elemwise_with_nfunc('scipy.special.i0', 1, 1)
@_scal_elemwise
def i0(x):
"""Modified Bessel function of the first kind of order 0."""
@_scal_elemwise_with_nfunc('scipy.special.i1', 1, 1)
@_scal_elemwise
def i1(x):
"""Modified Bessel function of the first kind of order 1."""
@_scal_elemwise_with_nfunc('scipy.special.iv', 2, 1)
@_scal_elemwise
def iv(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论