提交 00a11b60 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Stabilize logdiffexp_to_log1mexpdiff -inf, -inf case

上级 d2120171
...@@ -55,6 +55,7 @@ from pytensor.tensor.math import ( ...@@ -55,6 +55,7 @@ from pytensor.tensor.math import (
deg2rad, deg2rad,
digamma, digamma,
dot, dot,
eq,
erf, erf,
erfc, erfc,
exp, exp,
...@@ -3812,12 +3813,14 @@ def logmexpm1_to_log1mexp(fgraph, node): ...@@ -3812,12 +3813,14 @@ def logmexpm1_to_log1mexp(fgraph, node):
# log(exp(a) - exp(b)) -> a + log1mexp(b - a) # log(exp(a) - exp(b)) -> a + log1mexp(b - a)
# special care is taken for a == b == -inf, by wrapping -> switch(b == -inf, a, ...)
logdiffexp_to_log1mexpdiff = PatternNodeRewriter( logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
(log, (sub, (exp, "x"), (exp, "y"))), (log, (sub, (exp, "x"), (exp, "y"))),
(add, "x", (log1mexp, (sub, "y", "x"))), (switch, (eq, "y", -np.inf), "x", (add, "x", (log1mexp, (sub, "y", "x")))),
allow_multiple_clients=True, allow_multiple_clients=True,
name="logdiffexp_to_log1mexpdiff",
) )
register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff") register_stabilize(logdiffexp_to_log1mexpdiff)
# log(sigmoid(x) / (1 - sigmoid(x))) -> x # log(sigmoid(x) / (1 - sigmoid(x))) -> x
# i.e logit(sigmoid(x)) -> x # i.e logit(sigmoid(x)) -> x
......
...@@ -4581,7 +4581,7 @@ def test_log1mexp_stabilization(op_name): ...@@ -4581,7 +4581,7 @@ def test_log1mexp_stabilization(op_name):
) )
def test_logdiffexp(): def test_logdiffexp_stabilization():
rng = np.random.default_rng(3559) rng = np.random.default_rng(3559)
mode = Mode("py").including("stabilize").excluding("fusion") mode = Mode("py").including("stabilize").excluding("fusion")
...@@ -4618,6 +4618,11 @@ def test_logdiffexp(): ...@@ -4618,6 +4618,11 @@ def test_logdiffexp():
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test)) f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
) )
# Test edge cases
np.testing.assert_array_equal(
f([[-np.inf, -np.inf, -1]], [[-1, -np.inf, -np.inf]]),
[[np.nan, -np.inf, -1]],
)
def test_polygamma_specialization(): def test_polygamma_specialization():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论