提交 9fd3af71 authored 作者: Smit-create's avatar Smit-create 提交者: Ricardo Vieira

Add stabilization rewrite for log_diff_exp

上级 b1332b27
...@@ -3604,6 +3604,14 @@ log1pmexp_to_log1mexp = PatternNodeRewriter( ...@@ -3604,6 +3604,14 @@ log1pmexp_to_log1mexp = PatternNodeRewriter(
) )
register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp") register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
# log(exp(a) - exp(b)) -> a + log1mexp(b - a)
logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
(log, (sub, (exp, "x"), (exp, "y"))),
(add, "x", (log1mexp, (sub, "y", "x"))),
allow_multiple_clients=True,
)
register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
# log(sigmoid(x) / (1 - sigmoid(x))) -> x # log(sigmoid(x) / (1 - sigmoid(x))) -> x
# i.e logit(sigmoid(x)) -> x # i.e logit(sigmoid(x)) -> x
......
...@@ -4136,3 +4136,42 @@ def test_log1mexp_stabilization(): ...@@ -4136,3 +4136,42 @@ def test_log1mexp_stabilization():
f(np.array([-0.8, -0.6], dtype=config.floatX)), f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])), np.log(1 - np.exp([-0.8, -0.6])),
) )
def test_logdiffexp():
rng = np.random.default_rng(3559)
mode = Mode("py").including("stabilize").excluding("fusion")
x = fmatrix("x")
y = fmatrix("y")
f = function([x, y], log(exp(x) - exp(y)), mode=mode)
graph = f.maker.fgraph.toposort()
assert (
len(
[
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Exp, aes.Log))
]
)
== 0
)
assert (
len(
[
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Log1mexp)
]
)
== 1
)
y_test = rng.normal(size=(3, 2)).astype("float32")
x_test = rng.normal(size=(3, 2)).astype("float32") + y_test.max()
np.testing.assert_almost_equal(
f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论