提交 17d4842c authored 作者: Matt Graham's avatar Matt Graham

Adding scipy.special modified Bessel function ops.

上级 8fb9d9a7
...@@ -429,3 +429,61 @@ class J0(UnaryScalarOp): ...@@ -429,3 +429,61 @@ class J0(UnaryScalarOp):
j0(%(x)s);""" % locals() j0(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented') raise NotImplementedError('only floating point is implemented')
j0 = J0(upgrade_to_float, name='j0') j0 = J0(upgrade_to_float, name='j0')
class I1(UnaryScalarOp):
"""
Modified Bessel function of order 1.
"""
@staticmethod
def st_impl(x):
return scipy.special.i1(x)
def impl(self, x):
if imported_scipy_special:
return self.st_impl(x)
else:
super(I1, self).impl(x)
def grad(self, inp, grads):
raise NotImplementedError()
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
i1(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
i1 = I1(upgrade_to_float, name='i1')
class I0(UnaryScalarOp):
"""
Modified Bessel function of order 0.
"""
@staticmethod
def st_impl(x):
return scipy.special.i0(x)
def impl(self, x):
if imported_scipy_special:
return self.st_impl(x)
else:
super(I0, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
return [gz * i1(x)]
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
i0(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
i0 = I0(upgrade_to_float, name='i0')
...@@ -2309,6 +2309,16 @@ def j1(a): ...@@ -2309,6 +2309,16 @@ def j1(a):
"""Bessel function of the 1'th kind""" """Bessel function of the 1'th kind"""
@scal_elemwise
def i0(a):
"""Modified Bessel function of order 0."""
@scal_elemwise
def i1(a):
"Modified Bessel function of order 1."
@_scal_elemwise @_scal_elemwise
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
......
...@@ -290,6 +290,16 @@ def j1_inplace(a): ...@@ -290,6 +290,16 @@ def j1_inplace(a):
"""Bessel function of the 0'th kind""" """Bessel function of the 0'th kind"""
@_scal_inplace
def i0_inplace(a):
"""Modified Bessel function of order 0."""
@_scal_inplace
def i1_inplace(a):
"Modified Bessel function of order 1."
@_scal_inplace @_scal_inplace
def second_inplace(a): def second_inplace(a):
"""Fill `a` with `b`""" """Fill `a` with `b`"""
......
...@@ -1721,6 +1721,8 @@ if imported_scipy_special: ...@@ -1721,6 +1721,8 @@ if imported_scipy_special:
expected_chi2sf = scipy.stats.chi2.sf expected_chi2sf = scipy.stats.chi2.sf
expected_j0 = scipy.special.j0 expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1 expected_j1 = scipy.special.j1
expected_i0 = scipy.special.i0
expected_i1 = scipy.special.i1
skip_scipy = False skip_scipy = False
expected_erfcx = scipy.special.erfcx expected_erfcx = scipy.special.erfcx
else: else:
...@@ -1735,6 +1737,8 @@ else: ...@@ -1735,6 +1737,8 @@ else:
expected_chi2sf = [] expected_chi2sf = []
expected_j0 = [] expected_j0 = []
expected_j1 = [] expected_j1 = []
expected_i0 = []
expected_i1 = []
skip_scipy = "scipy is not present" skip_scipy = "scipy is not present"
ErfTester = makeBroadcastTester( ErfTester = makeBroadcastTester(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论