提交 00a11b60 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Stabilize logdiffexp_to_log1mexpdiff -inf, -inf case

上级 d2120171
......@@ -55,6 +55,7 @@ from pytensor.tensor.math import (
deg2rad,
digamma,
dot,
eq,
erf,
erfc,
exp,
......@@ -3812,12 +3813,14 @@ def logmexpm1_to_log1mexp(fgraph, node):
# log(exp(a) - exp(b)) -> a + log1mexp(b - a)
# special care is taken for a == b == -inf, by wrapping -> switch(b == -inf, a, ...)
logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
(log, (sub, (exp, "x"), (exp, "y"))),
(add, "x", (log1mexp, (sub, "y", "x"))),
(switch, (eq, "y", -np.inf), "x", (add, "x", (log1mexp, (sub, "y", "x")))),
allow_multiple_clients=True,
name="logdiffexp_to_log1mexpdiff",
)
register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
register_stabilize(logdiffexp_to_log1mexpdiff)
# log(sigmoid(x) / (1 - sigmoid(x))) -> x
# i.e logit(sigmoid(x)) -> x
......
......@@ -4581,7 +4581,7 @@ def test_log1mexp_stabilization(op_name):
)
def test_logdiffexp():
def test_logdiffexp_stabilization():
rng = np.random.default_rng(3559)
mode = Mode("py").including("stabilize").excluding("fusion")
......@@ -4618,6 +4618,11 @@ def test_logdiffexp():
np.testing.assert_almost_equal(
f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
)
# Test edge cases
np.testing.assert_array_equal(
f([[-np.inf, -np.inf, -1]], [[-1, -np.inf, -np.inf]]),
[[np.nan, -np.inf, -1]],
)
def test_polygamma_specialization():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论