提交 8717b8f7 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add stabilize opt for `log1mexp`

上级 f7d1bb3b
...@@ -77,6 +77,7 @@ from aesara.tensor.math import ( ...@@ -77,6 +77,7 @@ from aesara.tensor.math import (
int_div, int_div,
isinf, isinf,
log, log,
log1mexp,
log1p, log1p,
makeKeepDims, makeKeepDims,
) )
...@@ -3664,3 +3665,11 @@ register_local_1msigmoid = False ...@@ -3664,3 +3665,11 @@ register_local_1msigmoid = False
if register_local_1msigmoid: if register_local_1msigmoid:
register_canonicalize(local_1msigmoid) register_canonicalize(local_1msigmoid)
log1pmexp_to_log1mexp = PatternSub(
(log1p, (neg, (exp, "x"))),
(log1mexp, "x"),
allow_multiple_clients=True,
)
register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
...@@ -4439,3 +4439,24 @@ class TestSigmoidUtils: ...@@ -4439,3 +4439,24 @@ class TestSigmoidUtils:
assert is_1pexp(1 + 2 * exp_op(x), False) is None assert is_1pexp(1 + 2 * exp_op(x), False) is None
finally: finally:
config.warn__identify_1pexp_bug = backup config.warn__identify_1pexp_bug = backup
def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize")
x = vector()
f = function([x], log(1 - exp(x)), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [aet.log1mexp]
# Check values that would under or overflow without optimization
assert f([-(2.0 ** -55)]) != -np.inf
overflow_value = -500.0 if config.floatX == "float64" else -100.0
assert f([overflow_value]) < 0
# Check values around the optimization switch point np.log(0.5)
assert np.allclose(
f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论