Unverified commit e3f1e040, authored by Alessandro Gentili; committed by GitHub.

Remove Iv core Op in favor of Ive and add rewrite rule for log of it (#1929)

Parent commit: 4273eb87
...@@ -29,7 +29,6 @@ from pytensor.scalar.math import ( ...@@ -29,7 +29,6 @@ from pytensor.scalar.math import (
Erfinv, Erfinv,
GammaIncCInv, GammaIncCInv,
GammaIncInv, GammaIncInv,
Iv,
Ive, Ive,
Kve, Kve,
Log1mexp, Log1mexp,
...@@ -277,16 +276,6 @@ def jax_funcify_from_tfp(op, **kwargs): ...@@ -277,16 +276,6 @@ def jax_funcify_from_tfp(op, **kwargs):
return tfp_jax_op return tfp_jax_op
@jax_funcify.register(Iv)
def jax_funcify_Iv(op, **kwargs):
    """Dispatch ``Iv`` to TFP-on-JAX's exponentially scaled Bessel op.

    TFP only exposes ``bessel_ive(v, x) = iv(v, x) * exp(-|Re(x)|)``, so the
    unscaled value is recovered by dividing out the scaling factor.
    """
    bessel_ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")

    def unscaled_iv(v, x):
        # Undo the exp(-|Re(x)|) scaling that bessel_ive applies.
        scale = jnp.exp(-jnp.abs(jnp.real(x)))
        return bessel_ive(v, x) / scale

    return unscaled_iv
@jax_funcify.register(Ive) @jax_funcify.register(Ive)
def jax_funcify_Ive(op, **kwargs): def jax_funcify_Ive(op, **kwargs):
return try_import_tfp_jax_op(op, jax_op_name="bessel_ive") return try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
......
...@@ -1073,31 +1073,6 @@ class J0(UnaryScalarOp): ...@@ -1073,31 +1073,6 @@ class J0(UnaryScalarOp):
j0 = J0(upgrade_to_float, name="j0") j0 = J0(upgrade_to_float, name="j0")
class Iv(BinaryScalarOp):
    """Modified Bessel function of the first kind of order ``v`` (real).

    Python-only scalar Op: evaluation delegates to ``scipy.special.iv`` and
    no C implementation is provided.
    """

    # NumPy/SciPy dispatch spec: scipy.special.iv, 2 inputs -> 1 output.
    nfunc_spec = ("scipy.special.iv", 2, 1)

    def impl(self, v, x):
        # Direct SciPy evaluation.
        return special.iv(v, x)

    def grad(self, inputs, grads):
        v, x = inputs
        (gz,) = grads
        # Gradient w.r.t. the order v is not implemented; gradient w.r.t. x
        # uses the recurrence iv'(v, x) = (iv(v - 1, x) + iv(v + 1, x)) / 2.
        grad_v = grad_not_implemented(self, 0, v)
        grad_x = gz * (iv(v - 1, x) + iv(v + 1, x)) / 2.0
        return [grad_v, grad_x]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


iv = Iv(upgrade_to_float, name="iv")
class I1(UnaryScalarOp): class I1(UnaryScalarOp):
""" """
Modified Bessel function of the first kind of order 1. Modified Bessel function of the first kind of order 1.
...@@ -1111,7 +1086,7 @@ class I1(UnaryScalarOp): ...@@ -1111,7 +1086,7 @@ class I1(UnaryScalarOp):
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
(gz,) = grads (gz,) = grads
return [gz * (i0(x) + iv(2, x)) / 2.0] return [gz * (i0(x) + ive(2, x) * exp(abs(x))) / 2.0]
def c_code(self, *args, **kwargs): def c_code(self, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
......
...@@ -2429,9 +2429,14 @@ def i1(x): ...@@ -2429,9 +2429,14 @@ def i1(x):
"""Modified Bessel function of the first kind of order 1.""" """Modified Bessel function of the first kind of order 1."""
def iv(v, x):
    """Modified Bessel function of the first kind of order v (real).

    Computed as ``ive(v, x) * exp(abs(x))`` for numerical consistency with
    ``ive``. For large ``x``, prefer working in log-space:
    ``log(iv(v, x)) == log(ive(v, x)) + abs(x)`` to avoid overflow.
    """
    return ive(v, x) * exp(abs(x))
@scalar_elemwise @scalar_elemwise
......
...@@ -63,6 +63,7 @@ from pytensor.tensor.math import ( ...@@ -63,6 +63,7 @@ from pytensor.tensor.math import (
ge, ge,
int_div, int_div,
isinf, isinf,
ive,
kve, kve,
le, le,
log, log,
...@@ -3888,3 +3889,17 @@ local_log_kv = PatternNodeRewriter( ...@@ -3888,3 +3889,17 @@ local_log_kv = PatternNodeRewriter(
) )
register_stabilize(local_log_kv) register_stabilize(local_log_kv)
# Stabilization rewrite: ``iv`` expands to ``ive(v, x) * exp(abs(x))``, whose
# log would overflow for large ``x``.  Folding the log through keeps the
# computation in log-space: log(ive(v, x)) + abs(x).
local_log_iv = PatternNodeRewriter(
    # Rewrite log(iv(v, x)) = log(ive(v, x) * exp(abs(x))) -> log(ive(v, x)) + abs(x)
    (log, (mul, (ive, "v", "x"), (exp, (pt_abs, "x")))),
    (add, (log, (ive, "v", "x")), (pt_abs, "x")),
    allow_multiple_clients=True,
    name="local_log_iv",
    # Start the rewrite from the less likely ive node (log/mul/exp nodes are
    # far more common); get_clients_at_depth2 presumably walks up to the
    # enclosing log node two levels above — confirm against its definition.
    tracks=[ive],
    get_nodes=get_clients_at_depth2,
)
register_stabilize(local_log_iv)
...@@ -259,8 +259,12 @@ def isinf(): ... ...@@ -259,8 +259,12 @@ def isinf(): ...
def isnan(): ... def isnan(): ...
def iv(v, x):
    """Modified Bessel function of the first kind of order v (real).

    Computed as ``ive(v, x) * exp(abs(x))`` for numerical consistency.
    """
    return ive(v, x) * exp(abs(x))
@_as_xelemwise(ps.ive) @_as_xelemwise(ps.ive)
......
...@@ -38,7 +38,7 @@ from pytensor.link.jax.dispatch import jax_funcify ...@@ -38,7 +38,7 @@ from pytensor.link.jax.dispatch import jax_funcify
try: try:
pass import tensorflow_probability.substrates.jax.math # noqa: F401
TFP_INSTALLED = True TFP_INSTALLED = True
except ModuleNotFoundError: except ModuleNotFoundError:
......
...@@ -4785,6 +4785,19 @@ def test_log_kv_stabilization(): ...@@ -4785,6 +4785,19 @@ def test_log_kv_stabilization():
) )
def test_log_iv_stabilization():
    """log(iv(v, x)) must be stabilized to log(ive(v, x)) + abs(x)."""
    x = pt.scalar("x")
    log_iv = log(pt.iv(4.5, x))

    # Without the stabilize rewrite, the graph computes exp(abs(x)) and
    # overflows to inf at x = 1000.
    stabilized_mode = get_default_mode().including("stabilize")
    result = log_iv.eval({x: 1000.0}, mode=stabilized_mode)

    # Reference value: log(ive(4.5, 1000.0)) + 1000.0
    np.testing.assert_allclose(result, 995.6171788390135)
@pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"]) @pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"])
def test_pow_1_rewrite(shape): def test_pow_1_rewrite(shape):
x = pt.tensor("x", shape=shape) x = pt.tensor("x", shape=shape)
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment