提交 47278f72 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3883 from skaae/bessel

[WIP] add bessel0
......@@ -231,12 +231,6 @@ class Gamma(UnaryScalarOp):
if node.inputs[0].type in float_types:
return """%(z)s = tgamma(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gamma = Gamma(upgrade_to_float, name='gamma')
......@@ -275,12 +269,6 @@ class GammaLn(UnaryScalarOp):
return """%(z)s =
lgamma(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
gammaln = GammaLn(upgrade_to_float, name='gammaln')
......@@ -357,12 +345,6 @@ class Psi(UnaryScalarOp):
return """%(z)s =
_psi(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
psi = Psi(upgrade_to_float, name='psi')
......@@ -386,11 +368,62 @@ class Chi2SF(BinaryScalarOp):
return Chi2SF.st_impl(x, k)
else:
super(Chi2SF, self).impl(x, k)
chi2sf = Chi2SF(upgrade_to_float, name='chi2sf')
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
class J1(UnaryScalarOp):
"""
Bessel function of the 1'th kind
"""
chi2sf = Chi2SF(upgrade_to_float, name='chi2sf')
@staticmethod
def st_impl(x):
return scipy.special.j1(x)
def impl(self, x):
if imported_scipy_special:
return self.st_impl(x)
else:
super(J1, self).impl(x)
def grad(self, inp, grads):
raise NotImplementedError()
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
j1(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
j1 = J1(upgrade_to_float, name='j1')
class J0(UnaryScalarOp):
"""
Bessel function of the 0'th kind
"""
@staticmethod
def st_impl(x):
return scipy.special.j0(x)
def impl(self, x):
if imported_scipy_special:
return self.st_impl(x)
else:
super(J0, self).impl(x)
def grad(self, inp, grads):
x, = inp
gz, = grads
return [gz * -1 * j1(x)]
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
j0(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
j0 = J0(upgrade_to_float, name='j0')
......@@ -2202,6 +2202,16 @@ def chi2sf(x, k):
"""chi squared survival function"""
@_scal_elemwise
def j0(a):
"""Bessel function of the 0'th kind"""
@_scal_elemwise
def j1(a):
"""Bessel function of the 1'th kind"""
@_scal_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
......
......@@ -284,6 +284,16 @@ def chi2sf_inplace(x, k):
"""chi squared survival function"""
@_scal_inplace
def j0_inplace(a):
"""Bessel function of the 0'th kind"""
@_scal_inplace
def j1_inplace(a):
"""Bessel function of the 0'th kind"""
@_scal_inplace
def second_inplace(a):
"""Fill `a` with `b`"""
......
......@@ -1683,6 +1683,8 @@ if imported_scipy_special:
expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi
expected_chi2sf = lambda x, df: scipy.stats.chi2.sf(x, df).astype(x.dtype)
expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1
skip_scipy = False
if LooseVersion(scipy_version) >= LooseVersion("0.12.0"):
expected_erfcx = scipy.special.erfcx
......@@ -1700,6 +1702,8 @@ else:
expected_gammaln = []
expected_psi = []
expected_chi2sf = []
expected_j0 = []
expected_j1 = []
skip_scipy = "scipy is not present"
skip_scipy12 = "scipy is not present"
......@@ -1867,6 +1871,43 @@ Chi2SFInplaceTester = makeBroadcastTester(
skip=skip_scipy,
name='Chi2SF')
_good_broadcast_unary_j = dict(
normal=(rand_ranged(0.1, 8, (2, 3)),),)
J0Tester = makeBroadcastTester(
op=tensor.j0,
expected=expected_j0,
good=_good_broadcast_unary_j,
grad=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
J0InplaceTester = makeBroadcastTester(
op=inplace.j0_inplace,
expected=expected_j0,
good=_good_broadcast_unary_j,
grad=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
J1Tester = makeBroadcastTester(
op=tensor.j1,
expected=expected_j1,
good=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
J1InplaceTester = makeBroadcastTester(
op=inplace.j1_inplace,
expected=expected_j1,
good=_good_broadcast_unary_j,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
ZerosLikeTester = makeBroadcastTester(
op=tensor.zeros_like,
expected=numpy.zeros_like,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论