提交 82284b21 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6396 from nouiz/scipy_nfunc

Scipy nfunc
......@@ -26,6 +26,8 @@ except (ImportError, ValueError):
class Erf(UnaryScalarOp):
nfunc_spec = ('scipy.special.erf', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erf(x)
......@@ -58,6 +60,8 @@ erf = Erf(upgrade_to_float, name='erf')
class Erfc(UnaryScalarOp):
nfunc_spec = ('scipy.special.erfc', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfc(x)
......@@ -105,6 +109,8 @@ class Erfcx(UnaryScalarOp):
running on GPU an optimization will replace it with a gpu version.
"""
nfunc_spec = ('scipy.special.erfcx', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcx(x)
......@@ -140,6 +146,8 @@ class Erfinv(UnaryScalarOp):
(TODO) Find a C implementation of erfinv for CPU.
"""
nfunc_spec = ('scipy.special.erfinv', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfinv(x)
......@@ -173,6 +181,8 @@ erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')
class Erfcinv(UnaryScalarOp):
nfunc_spec = ('scipy.special.erfcinv', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcinv(x)
......@@ -206,6 +216,8 @@ erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')
class Gamma(UnaryScalarOp):
nfunc_spec = ('scipy.special.gamma', 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.gamma(x)
......@@ -243,6 +255,8 @@ class GammaLn(UnaryScalarOp):
Log gamma function.
"""
nfunc_spec = ('scipy.special.gammaln', 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.gammaln(x)
......@@ -287,6 +301,8 @@ class Psi(UnaryScalarOp):
Derivative of log gamma function.
"""
nfunc_spec = ('scipy.special.psi', 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.psi(x)
......@@ -472,6 +488,7 @@ class Chi2SF(BinaryScalarOp):
https://github.com/Theano/Theano_lgpl.git
"""
nfunc_spec = ('scipy.stats.chi2.sf', 2, 1)
@staticmethod
def st_impl(x, k):
......@@ -489,6 +506,7 @@ class Jv(BinaryScalarOp):
"""
Bessel function of the first kind of order v (real).
"""
nfunc_spec = ('scipy.special.jv', 2, 1)
@staticmethod
def st_impl(v, x):
......@@ -513,6 +531,7 @@ class J1(UnaryScalarOp):
"""
Bessel function of the first kind of order 1.
"""
nfunc_spec = ('scipy.special.j1', 1, 1)
@staticmethod
def st_impl(x):
......@@ -544,6 +563,7 @@ class J0(UnaryScalarOp):
"""
Bessel function of the first kind of order 0.
"""
nfunc_spec = ('scipy.special.j0', 1, 1)
@staticmethod
def st_impl(x):
......@@ -575,6 +595,7 @@ class Iv(BinaryScalarOp):
"""
Modified Bessel function of the first kind of order v (real).
"""
nfunc_spec = ('scipy.special.iv', 2, 1)
@staticmethod
def st_impl(v, x):
......@@ -599,6 +620,7 @@ class I1(UnaryScalarOp):
"""
Modified Bessel function of the first kind of order 1.
"""
nfunc_spec = ('scipy.special.i1', 1, 1)
@staticmethod
def st_impl(x):
......@@ -622,6 +644,7 @@ class I0(UnaryScalarOp):
"""
Modified Bessel function of the first kind of order 0.
"""
nfunc_spec = ('scipy.special.i0', 1, 1)
@staticmethod
def st_impl(x):
......
......@@ -392,17 +392,13 @@ second dimension
inplace_pattern = frozendict({})
self.name = name
self.scalar_op = scalar_op
self.inplace_pattern = frozendict(inplace_pattern)
self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in self.inplace_pattern.items())
self.ufunc = None
self.nfunc = None
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = getattr(np, nfunc_spec[0])
self.__setstate__(self.__dict__)
super(Elemwise, self).__init__(openmp=openmp)
def __getstate__(self):
......@@ -417,12 +413,6 @@ second dimension
self.ufunc = None
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(np, self.nfunc_spec[0])
elif 0 < self.scalar_op.nin < 32:
self.ufunc = np.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin,
self.scalar_op.nout)
def get_output_info(self, dim_shuffle, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
......@@ -655,9 +645,28 @@ second dimension
return ret
def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minutes
# NumPy ufunc support only up to 31 inputs.
# Postpone the ufunc building to the last minutes due to:
# - NumPy ufunc support only up to 31 inputs.
# But our c code support more.
# - nfunc is reused for scipy and scipy is optional
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(np, self.nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = self.nfunc_spec[0].split(".")
for idx in range(1, len(self.nfunc_spec[0])):
try:
module = __import__('.'.join(symb[:idx]))
except ImportError:
break
for sub in symb[1:]:
try:
module = getattr(module, sub)
except AttributeError:
module = None
break
self.nfunc = module
if (len(node.inputs) < 32 and
(self.nfunc is None or
self.scalar_op.nin != len(node.inputs)) and
......@@ -743,6 +752,10 @@ second dimension
ufunc_args = inputs
ufunc_kwargs = {}
# We supported in the past calling manually op.perform.
# To keep that support we need to sometimes call self.prepare_node
if self.nfunc is None and self.ufunc is None:
self.prepare_node(node, None, None, 'py')
if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc
nout = self.nfunc_spec[2]
......
......@@ -1782,14 +1782,16 @@ ErfcxTester = makeBroadcastTester(
good=_good_broadcast_unary_normal_float_no_complex_small_neg_range,
grad=_grad_broadcast_unary_normal_small_neg_range,
eps=2e-10,
mode=mode_no_scipy)
mode=mode_no_scipy,
skip=skip_scipy)
ErfcxInplaceTester = makeBroadcastTester(
op=inplace.erfcx_inplace,
expected=expected_erfcx,
good=_good_broadcast_unary_normal_float_no_complex_small_neg_range,
eps=2e-10,
mode=mode_no_scipy,
inplace=True)
inplace=True,
skip=skip_scipy)
ErfinvTester = makeBroadcastTester(
op=tensor.erfinv,
......@@ -2015,7 +2017,8 @@ def test_verify_jv_grad():
# Verify Jv gradient.
# Implemented separately due to need to fix first input for which grad is
# not defined.
if skip_scipy:
raise SkipTest("SciPy needed")
v_val, x_val = _grad_broadcast_binary_bessel['normal']
def fixed_first_input_jv(x):
......@@ -2082,6 +2085,8 @@ def test_verify_iv_grad():
# Verify Iv gradient.
# Implemented separately due to need to fix first input for which grad is
# not defined.
if skip_scipy:
raise SkipTest("SciPy needed")
v_val, x_val = _grad_broadcast_binary_bessel['normal']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论