提交 133abe80 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement Kve Op and Kv helper

上级 3523bfa5
......@@ -31,6 +31,7 @@ from pytensor.scalar.math import (
GammaIncInv,
Iv,
Ive,
Kve,
Log1mexp,
Psi,
TriGamma,
......@@ -288,9 +289,12 @@ def jax_funcify_Iv(op, **kwargs):
@jax_funcify.register(Ive)
def jax_funcify_Ive(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
return try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
return ive
@jax_funcify.register(Kve)
def jax_funcify_Kve(op, **kwargs):
return try_import_tfp_jax_op(op, jax_op_name="bessel_kve")
@jax_funcify.register(Log1mexp)
......
......@@ -1281,6 +1281,38 @@ class Ive(BinaryScalarOp):
ive = Ive(upgrade_to_float, name="ive")
class Kve(BinaryScalarOp):
"""Exponentially scaled modified Bessel function of the second kind of real order v."""
nfunc_spec = ("scipy.special.kve", 2, 1)
@staticmethod
def st_impl(v, x):
return scipy.special.kve(v, x)
def impl(self, v, x):
return self.st_impl(v, x)
def L_op(self, inputs, outputs, output_grads):
v, x = inputs
[kve_vx] = outputs
[g_out] = output_grads
# (1 -v/x) * kve(v, x) - kve(v - 1, x)
kve_vm1x = self(v - 1, x)
dx = (1 - v / x) * kve_vx - kve_vm1x
return [
grad_not_implemented(self, 0, v),
g_out * dx,
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
kve = Kve(upgrade_to_float, name="kve")
class Sigmoid(UnaryScalarOp):
"""
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit
......
......@@ -1229,6 +1229,16 @@ def ive(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def kve(v, x):
"""Exponentially scaled modified Bessel function of the second kind of real order v."""
def kv(v, x):
"""Modified Bessel function of the second kind of real order v."""
return kve(v, x) * exp(-x)
@scalar_elemwise
def sigmoid(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
......@@ -3040,6 +3050,8 @@ __all__ = [
"i1",
"iv",
"ive",
"kv",
"kve",
"sigmoid",
"expit",
"softplus",
......
......@@ -21,6 +21,7 @@ from pytensor.tensor.math import (
gammainccinv,
gammaincinv,
iv,
kve,
log,
log1mexp,
polygamma,
......@@ -157,6 +158,7 @@ def test_erfinv():
(erfcx, (0.7,)),
(erfcinv, (0.7,)),
(iv, (0.3, 0.7)),
(kve, (-2.5, 2.0)),
],
)
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
......
......@@ -3,7 +3,7 @@ import warnings
import numpy as np
import pytest
from pytensor.gradient import verify_grad
from pytensor.gradient import NullTypeGradError, verify_grad
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import Elemwise
......@@ -18,7 +18,7 @@ from pytensor import function, grad
from pytensor import tensor as pt
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor import gammaincc, inplace, vector
from pytensor.tensor import gammaincc, inplace, kv, kve, vector
from tests import unittest_tools as utt
from tests.tensor.utils import (
_good_broadcast_unary_chi2sf,
......@@ -1196,3 +1196,37 @@ class TestHyp2F1Grad:
[dd for i, dd in enumerate(expected_dds) if i in wrt],
rtol=rtol,
)
def test_kve():
rng = np.random.default_rng(3772)
v = vector("v")
x = vector("x")
out = kve(v[:, None], x[None, :])
test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
test_x = np.linspace(0, 1005, 10, dtype=x.type.dtype)
np.testing.assert_allclose(
out.eval({v: test_v, x: test_x}),
scipy.special.kve(test_v[:, None], test_x[None, :]),
)
with pytest.raises(NullTypeGradError):
grad(out.sum(), v)
verify_grad(lambda x: kv(4.5, x), [test_x + 0.5], rng=rng)
def test_kv():
v = vector("v")
x = vector("x")
out = kv(v[:, None], x[None, :])
test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
test_x = np.linspace(0, 512, 10, dtype=x.type.dtype)
np.testing.assert_allclose(
out.eval({v: test_v, x: test_x}),
scipy.special.kv(test_v[:, None], test_x[None, :]),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论