Unverified 提交 e3f1e040 authored 作者: Alessandro Gentili's avatar Alessandro Gentili 提交者: GitHub

Remove Iv core Op in favor of Ive and add rewrite rule for log of it (#1929)

上级 4273eb87
......@@ -29,7 +29,6 @@ from pytensor.scalar.math import (
Erfinv,
GammaIncCInv,
GammaIncInv,
Iv,
Ive,
Kve,
Log1mexp,
......@@ -277,16 +276,6 @@ def jax_funcify_from_tfp(op, **kwargs):
return tfp_jax_op
@jax_funcify.register(Iv)
def jax_funcify_Iv(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
def iv(v, x):
return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x)))
return iv
@jax_funcify.register(Ive)
def jax_funcify_Ive(op, **kwargs):
return try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
......
......@@ -1073,31 +1073,6 @@ class J0(UnaryScalarOp):
j0 = J0(upgrade_to_float, name="j0")
class Iv(BinaryScalarOp):
"""
Modified Bessel function of the first kind of order v (real).
"""
nfunc_spec = ("scipy.special.iv", 2, 1)
def impl(self, v, x):
return special.iv(v, x)
def grad(self, inputs, grads):
v, x = inputs
(gz,) = grads
return [
grad_not_implemented(self, 0, v),
gz * (iv(v - 1, x) + iv(v + 1, x)) / 2.0,
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
iv = Iv(upgrade_to_float, name="iv")
class I1(UnaryScalarOp):
"""
Modified Bessel function of the first kind of order 1.
......@@ -1111,7 +1086,7 @@ class I1(UnaryScalarOp):
def grad(self, inputs, grads):
(x,) = inputs
(gz,) = grads
return [gz * (i0(x) + iv(2, x)) / 2.0]
return [gz * (i0(x) + ive(2, x) * exp(abs(x))) / 2.0]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
......
......@@ -2429,9 +2429,14 @@ def i1(x):
"""Modified Bessel function of the first kind of order 1."""
@scalar_elemwise
def iv(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
"""Modified Bessel function of the first kind of order v (real).
Computed as ``ive(v, x) * exp(abs(x))`` for numerical consistency with
``ive``. For large ``x``, prefer working in log-space:
``log(iv(v, x)) == log(ive(v, x)) + abs(x)`` to avoid overflow.
"""
return ive(v, x) * exp(abs(x))
@scalar_elemwise
......
......@@ -63,6 +63,7 @@ from pytensor.tensor.math import (
ge,
int_div,
isinf,
ive,
kve,
le,
log,
......@@ -3888,3 +3889,17 @@ local_log_kv = PatternNodeRewriter(
)
register_stabilize(local_log_kv)
local_log_iv = PatternNodeRewriter(
# Rewrite log(iv(v, x)) = log(ive(v, x) * exp(abs(x))) -> log(ive(v, x)) + abs(x)
(log, (mul, (ive, "v", "x"), (exp, (pt_abs, "x")))),
(add, (log, (ive, "v", "x")), (pt_abs, "x")),
allow_multiple_clients=True,
name="local_log_iv",
# Start the rewrite from the less likely ive node
tracks=[ive],
get_nodes=get_clients_at_depth2,
)
register_stabilize(local_log_iv)
......@@ -259,8 +259,12 @@ def isinf(): ...
def isnan(): ...
@_as_xelemwise(ps.iv)
def iv(): ...
def iv(v, x):
"""Modified Bessel function of the first kind of order v (real).
Computed as ``ive(v, x) * exp(abs(x))`` for numerical consistency.
"""
return ive(v, x) * exp(abs(x))
@_as_xelemwise(ps.ive)
......
......@@ -38,7 +38,7 @@ from pytensor.link.jax.dispatch import jax_funcify
try:
pass
import tensorflow_probability.substrates.jax.math # noqa: F401
TFP_INSTALLED = True
except ModuleNotFoundError:
......
......@@ -4785,6 +4785,19 @@ def test_log_kv_stabilization():
)
def test_log_iv_stabilization():
x = pt.scalar("x")
out = log(pt.iv(4.5, x))
# Expression would overflow to inf without rewrite
mode = get_default_mode().including("stabilize")
# Reference value log(ive(4.5, 1000.0)) + 1000.0
np.testing.assert_allclose(
out.eval({x: 1000.0}, mode=mode),
995.6171788390135,
)
@pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"])
def test_pow_1_rewrite(shape):
x = pt.tensor("x", shape=shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论