提交 33a4d488 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Add stabilization rewrite for log of kv

上级 133abe80
...@@ -56,6 +56,7 @@ from pytensor.tensor.math import ( ...@@ -56,6 +56,7 @@ from pytensor.tensor.math import (
ge, ge,
int_div, int_div,
isinf, isinf,
kve,
le, le,
log, log,
log1mexp, log1mexp,
...@@ -3494,3 +3495,18 @@ local_polygamma_to_tri_gamma = PatternNodeRewriter( ...@@ -3494,3 +3495,18 @@ local_polygamma_to_tri_gamma = PatternNodeRewriter(
) )
register_specialize(local_polygamma_to_tri_gamma) register_specialize(local_polygamma_to_tri_gamma)
local_log_kv = PatternNodeRewriter(
# Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
# During stabilize -x is converted to -1.0 * x
(log, (mul, (kve, "v", "x"), (exp, (mul, -1.0, "x")))),
(sub, (log, (kve, "v", "x")), "x"),
allow_multiple_clients=True,
name="local_log_kv",
# Start the rewrite from the less likely kve node
tracks=[kve],
get_nodes=get_clients_at_depth2,
)
register_stabilize(local_log_kv)
...@@ -61,6 +61,7 @@ from pytensor.tensor.math import ( ...@@ -61,6 +61,7 @@ from pytensor.tensor.math import (
ge, ge,
gt, gt,
int_div, int_div,
kv,
le, le,
log, log,
log1mexp, log1mexp,
...@@ -4578,3 +4579,17 @@ def test_local_batched_matmul_to_core_matmul(): ...@@ -4578,3 +4579,17 @@ def test_local_batched_matmul_to_core_matmul():
x_test = rng.normal(size=(5, 3, 2)) x_test = rng.normal(size=(5, 3, 2))
y_test = rng.normal(size=(5, 2, 2)) y_test = rng.normal(size=(5, 2, 2))
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test) np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
def test_log_kv_stabilization():
x = pt.scalar("x")
out = log(kv(4.5, x))
# Expression would underflow to -inf without rewrite
mode = get_default_mode().including("stabilize")
# Reference value from mpmath
# mpmath.log(mpmath.besselk(4.5, 1000.0))
np.testing.assert_allclose(
out.eval({x: 1000.0}, mode=mode),
-1003.2180912984705,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论