提交 68ae1a76 authored 作者: skaae's avatar skaae

add bessel of the zeroth kind

上级 71a3700f
......@@ -394,3 +394,45 @@ class Chi2SF(BinaryScalarOp):
return hash(type(self))
chi2sf = Chi2SF(upgrade_to_float, name='chi2sf')
class J1(UnaryScalarOp):
"""
Bessel function of the 1'th kind
"""
@staticmethod
def impl(x):
return scipy.special.j1(x)
def grad(self, inp, grads):
raise NotImplementedError()
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
j1 = J1(upgrade_to_float, name='j1')
class J0(UnaryScalarOp):
"""
Bessel function of the 0'th kind
"""
@staticmethod
def impl(x):
return scipy.special.j0(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
return [gz * -1 * j1(x)]
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
j0 = J0(upgrade_to_float, name='j0')
......@@ -2201,6 +2201,16 @@ def chi2sf(x, k):
"""chi squared survival function"""
@_scal_elemwise
def j0(a):
"""Bessel function of the 0'th kind"""
@_scal_elemwise
def j1(a):
"""Bessel function of the 1'th kind"""
@_scal_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
......
......@@ -284,6 +284,16 @@ def chi2sf_inplace(x, k):
"""chi squared survival function"""
@_scal_inplace
def j0_inplace(a):
"""Bessel function of the 0'th kind"""
@_scal_inplace
def j1_inplace(a):
"""Bessel function of the 0'th kind"""
@_scal_inplace
def second_inplace(a):
"""Fill `a` with `b`"""
......
......@@ -1677,6 +1677,8 @@ if imported_scipy_special:
expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi
expected_chi2sf = lambda x, df: scipy.stats.chi2.sf(x, df).astype(x.dtype)
expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1
skip_scipy = False
if LooseVersion(scipy_version) >= LooseVersion("0.12.0"):
expected_erfcx = scipy.special.erfcx
......@@ -1861,6 +1863,43 @@ Chi2SFInplaceTester = makeBroadcastTester(
skip=skip_scipy,
name='Chi2SF')
_good_broadcast_unary_j = dict(
normal=(rand_ranged(0.1, 8, (2, 3)),),)
J0Tester = makeBroadcastTester(
op=tensor.j0,
expected=expected_j0,
good=_good_broadcast_unary_j,
grad=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
J0InplaceTester = makeBroadcastTester(
op=inplace.j0_inplace,
expected=expected_j0,
good=_good_broadcast_unary_j,
grad=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
J1Tester = makeBroadcastTester(
op=tensor.j1,
expected=expected_j1,
good=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
J1InplaceTester = makeBroadcastTester(
op=inplace.j1_inplace,
expected=expected_j1,
good=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
ZerosLikeTester = makeBroadcastTester(
op=tensor.zeros_like,
expected=numpy.zeros_like,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论