提交 84fdef4f authored 作者: Frederic Bastien's avatar Frederic Bastien

Use another mechanism that make sure all instance will use the scipy nfunc.

上级 aff96770
...@@ -26,6 +26,7 @@ except (ImportError, ValueError): ...@@ -26,6 +26,7 @@ except (ImportError, ValueError):
class Erf(UnaryScalarOp): class Erf(UnaryScalarOp):
nfunc_spec = ('scipy.special.erf', 1, 1)
def impl(self, x): def impl(self, x):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erf(x) return scipy.special.erf(x)
...@@ -58,6 +59,8 @@ erf = Erf(upgrade_to_float, name='erf') ...@@ -58,6 +59,8 @@ erf = Erf(upgrade_to_float, name='erf')
class Erfc(UnaryScalarOp): class Erfc(UnaryScalarOp):
nfunc_spec = ('scipy.special.erfc', 1, 1)
def impl(self, x): def impl(self, x):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfc(x) return scipy.special.erfc(x)
...@@ -105,6 +108,7 @@ class Erfcx(UnaryScalarOp): ...@@ -105,6 +108,7 @@ class Erfcx(UnaryScalarOp):
running on GPU an optimization will replace it with a gpu version. running on GPU an optimization will replace it with a gpu version.
""" """
nfunc_spec = ('scipy.special.erfcx', 1, 1)
def impl(self, x): def impl(self, x):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfcx(x) return scipy.special.erfcx(x)
...@@ -140,6 +144,7 @@ class Erfinv(UnaryScalarOp): ...@@ -140,6 +144,7 @@ class Erfinv(UnaryScalarOp):
(TODO) Find a C implementation of erfinv for CPU. (TODO) Find a C implementation of erfinv for CPU.
""" """
nfunc_spec = ('scipy.special.erfinv', 1, 1)
def impl(self, x): def impl(self, x):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfinv(x) return scipy.special.erfinv(x)
...@@ -173,6 +178,8 @@ erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv') ...@@ -173,6 +178,8 @@ erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')
class Erfcinv(UnaryScalarOp): class Erfcinv(UnaryScalarOp):
nfunc_spec = ('scipy.special.erfcinv', 1, 1)
def impl(self, x): def impl(self, x):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfcinv(x) return scipy.special.erfcinv(x)
...@@ -206,6 +213,8 @@ erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv') ...@@ -206,6 +213,8 @@ erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')
class Gamma(UnaryScalarOp): class Gamma(UnaryScalarOp):
nfunc_spec = ('scipy.special.gamma', 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
return scipy.special.gamma(x) return scipy.special.gamma(x)
...@@ -243,6 +252,8 @@ class GammaLn(UnaryScalarOp): ...@@ -243,6 +252,8 @@ class GammaLn(UnaryScalarOp):
Log gamma function. Log gamma function.
""" """
nfunc_spec = ('scipy.special.gammaln', 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
return scipy.special.gammaln(x) return scipy.special.gammaln(x)
...@@ -287,6 +298,8 @@ class Psi(UnaryScalarOp): ...@@ -287,6 +298,8 @@ class Psi(UnaryScalarOp):
Derivative of log gamma function. Derivative of log gamma function.
""" """
nfunc_spec = ('scipy.special.psi', 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
return scipy.special.psi(x) return scipy.special.psi(x)
...@@ -472,6 +485,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -472,6 +485,7 @@ class Chi2SF(BinaryScalarOp):
https://github.com/Theano/Theano_lgpl.git https://github.com/Theano/Theano_lgpl.git
""" """
nfunc_spec = ('scipy.stats.chi2.sf', 2, 1)
@staticmethod @staticmethod
def st_impl(x, k): def st_impl(x, k):
...@@ -489,6 +503,7 @@ class Jv(BinaryScalarOp): ...@@ -489,6 +503,7 @@ class Jv(BinaryScalarOp):
""" """
Bessel function of the first kind of order v (real). Bessel function of the first kind of order v (real).
""" """
nfunc_spec = ('scipy.special.jv', 2, 1)
@staticmethod @staticmethod
def st_impl(v, x): def st_impl(v, x):
...@@ -513,6 +528,7 @@ class J1(UnaryScalarOp): ...@@ -513,6 +528,7 @@ class J1(UnaryScalarOp):
""" """
Bessel function of the first kind of order 1. Bessel function of the first kind of order 1.
""" """
nfunc_spec = ('scipy.special.j1', 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -544,6 +560,7 @@ class J0(UnaryScalarOp): ...@@ -544,6 +560,7 @@ class J0(UnaryScalarOp):
""" """
Bessel function of the first kind of order 0. Bessel function of the first kind of order 0.
""" """
nfunc_spec = ('scipy.special.j0', 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -575,6 +592,7 @@ class Iv(BinaryScalarOp): ...@@ -575,6 +592,7 @@ class Iv(BinaryScalarOp):
""" """
Modified Bessel function of the first kind of order v (real). Modified Bessel function of the first kind of order v (real).
""" """
nfunc_spec = ('scipy.special.iv', 2, 1)
@staticmethod @staticmethod
def st_impl(v, x): def st_impl(v, x):
...@@ -599,6 +617,7 @@ class I1(UnaryScalarOp): ...@@ -599,6 +617,7 @@ class I1(UnaryScalarOp):
""" """
Modified Bessel function of the first kind of order 1. Modified Bessel function of the first kind of order 1.
""" """
nfunc_spec = ('scipy.special.i1', 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
...@@ -622,6 +641,7 @@ class I0(UnaryScalarOp): ...@@ -622,6 +641,7 @@ class I0(UnaryScalarOp):
""" """
Modified Bessel function of the first kind of order 0. Modified Bessel function of the first kind of order 0.
""" """
nfunc_spec = ('scipy.special.i0', 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
......
...@@ -2342,42 +2342,42 @@ def arctanh(a): ...@@ -2342,42 +2342,42 @@ def arctanh(a):
"""hyperbolic arc tangent of a""" """hyperbolic arc tangent of a"""
@_scal_elemwise_with_nfunc('scipy.special.erf', 1, 1) @_scal_elemwise
def erf(a): def erf(a):
"""error function""" """error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfc', 1, 1) @_scal_elemwise
def erfc(a): def erfc(a):
"""complementary error function""" """complementary error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfcx', 1, 1) @_scal_elemwise
def erfcx(a): def erfcx(a):
"""scaled complementary error function""" """scaled complementary error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfinv', 1, 1) @_scal_elemwise
def erfinv(a): def erfinv(a):
"""inverse error function""" """inverse error function"""
@_scal_elemwise_with_nfunc('scipy.special.erfcinv', 1, 1) @_scal_elemwise
def erfcinv(a): def erfcinv(a):
"""inverse complementary error function""" """inverse complementary error function"""
@_scal_elemwise_with_nfunc('scipy.special.gamma', 1, 1) @_scal_elemwise
def gamma(a): def gamma(a):
"""gamma function""" """gamma function"""
@_scal_elemwise_with_nfunc('scipy.special.gammaln', 1, 1) @_scal_elemwise
def gammaln(a): def gammaln(a):
"""log gamma function""" """log gamma function"""
@_scal_elemwise_with_nfunc('scipy.special.psi', 1, 1) @_scal_elemwise
def psi(a): def psi(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
...@@ -2387,37 +2387,37 @@ def tri_gamma(a): ...@@ -2387,37 +2387,37 @@ def tri_gamma(a):
"""second derivative of the log gamma function""" """second derivative of the log gamma function"""
@_scal_elemwise_with_nfunc('scipy.stats.chi2.sf', 2, 1) @_scal_elemwise
def chi2sf(x, k): def chi2sf(x, k):
"""chi squared survival function""" """chi squared survival function"""
@_scal_elemwise_with_nfunc('scipy.special.j0', 1, 1) @_scal_elemwise
def j0(x): def j0(x):
"""Bessel function of the first kind of order 0.""" """Bessel function of the first kind of order 0."""
@_scal_elemwise_with_nfunc('scipy.special.j1', 1, 1) @_scal_elemwise
def j1(x): def j1(x):
"""Bessel function of the first kind of order 1.""" """Bessel function of the first kind of order 1."""
@_scal_elemwise_with_nfunc('scipy.special.jv', 2, 1) @_scal_elemwise
def jv(v, x): def jv(v, x):
"""Bessel function of the first kind of order v (real).""" """Bessel function of the first kind of order v (real)."""
@_scal_elemwise_with_nfunc('scipy.special.i0', 1, 1) @_scal_elemwise
def i0(x): def i0(x):
"""Modified Bessel function of the first kind of order 0.""" """Modified Bessel function of the first kind of order 0."""
@_scal_elemwise_with_nfunc('scipy.special.i1', 1, 1) @_scal_elemwise
def i1(x): def i1(x):
"""Modified Bessel function of the first kind of order 1.""" """Modified Bessel function of the first kind of order 1."""
@_scal_elemwise_with_nfunc('scipy.special.iv', 2, 1) @_scal_elemwise
def iv(v, x): def iv(v, x):
"""Modified Bessel function of the first kind of order v (real).""" """Modified Bessel function of the first kind of order v (real)."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论