提交 2cef9c0e authored 作者: David Horsley's avatar David Horsley 提交者: Ricardo Vieira

add Exponentially scaled modified Bessel function

Fixes #542
上级 68b41a48
......@@ -27,6 +27,7 @@ from pytensor.scalar.math import (
Erfcx,
Erfinv,
Iv,
Ive,
Log1mexp,
Psi,
TriGamma,
......@@ -267,6 +268,13 @@ def jax_funcify_Iv(op, **kwargs):
return iv
@jax_funcify.register(Ive)
def jax_funcify_Ive(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
return ive
@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
......
......@@ -1197,6 +1197,37 @@ class I0(UnaryScalarOp):
i0 = I0(upgrade_to_float, name="i0")
class Ive(BinaryScalarOp):
"""
Exponentially scaled modified Bessel function of the first kind of order v (real).
"""
nfunc_spec = ("scipy.special.ive", 2, 1)
@staticmethod
def st_impl(v, x):
return scipy.special.ive(v, x)
def impl(self, v, x):
return self.st_impl(v, x)
def grad(self, inputs, grads):
v, x = inputs
(gz,) = grads
return [
grad_not_implemented(self, 0, v),
gz
* (ive(v - 1, x) - 2.0 * _unsafe_sign(x) * ive(v, x) + ive(v + 1, x))
/ 2.0,
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
ive = Ive(upgrade_to_float, name="ive")
class Sigmoid(UnaryScalarOp):
"""
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit
......
......@@ -313,6 +313,11 @@ def iv_inplace(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def ive_inplace(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def sigmoid_inplace(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
......
......@@ -1435,6 +1435,11 @@ def iv(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def ive(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def sigmoid(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
......@@ -3039,6 +3044,7 @@ __all__ = [
"i0",
"i1",
"iv",
"ive",
"sigmoid",
"expit",
"softplus",
......
......@@ -75,6 +75,7 @@ expected_jv = scipy.special.jv
expected_i0 = scipy.special.i0
expected_i1 = scipy.special.i1
expected_iv = scipy.special.iv
expected_ive = scipy.special.ive
expected_erfcx = scipy.special.erfcx
expected_sigmoid = scipy.special.expit
expected_hyp2f1 = scipy.special.hyp2f1
......@@ -639,6 +640,23 @@ TestIvInplaceBroadcast = makeBroadcastTester(
inplace=True,
)
TestIveBroadcast = makeBroadcastTester(
op=at.ive,
expected=expected_ive,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
)
TestIveInplaceBroadcast = makeBroadcastTester(
op=inplace.ive_inplace,
expected=expected_ive,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
def test_verify_iv_grad():
# Verify Iv gradient.
......@@ -652,6 +670,18 @@ def test_verify_iv_grad():
utt.verify_grad(fixed_first_input_iv, [x_val])
def test_verify_ive_grad():
# Verify Ive gradient.
# Implemented separately due to need to fix first input for which grad is
# not defined.
v_val, x_val = _grad_broadcast_binary_bessel["normal"]
def fixed_first_input_ive(x):
return at.ive(v_val, x)
utt.verify_grad(fixed_first_input_ive, [x_val])
TestSigmoidBroadcast = makeBroadcastTester(
op=at.sigmoid,
expected=expected_sigmoid,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论