Unverified 提交 d3bbc20a authored 作者: Luca Citi's avatar Luca Citi 提交者: GitHub

Cover more cases of `log1mexp` stabilization (#1483)

* Created some tests that fail due to #1476 * Fixes 1476 and other ways to create a log1mexp * Reimplemented logmexpm1_to_log1mexp by tracking expm1 and then looking through the clients * Absorbed the rewrite log1pexp_to_softplus into the new rewrite for log1mexp * Fixed bug where I forgot to check whether result of is_neg was None or not before proceeding --------- Co-authored-by: 's avatarLuca Citi <lciti@ieee.org>
上级 6aeed97c
......@@ -64,6 +64,7 @@ from pytensor.tensor.math import (
log,
log1mexp,
log1p,
log1pexp,
makeKeepDims,
maximum,
mul,
......@@ -2999,12 +3000,6 @@ log1msigm_to_softplus = PatternNodeRewriter(
tracks=[sigmoid],
get_nodes=get_clients_at_depth2,
)
log1pexp_to_softplus = PatternNodeRewriter(
(log1p, (exp, "x")),
(softplus, "x"),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True,
)
log1p_neg_sigmoid = PatternNodeRewriter(
(log1p, (neg, (sigmoid, "x"))),
(neg, (softplus, "x")),
......@@ -3016,7 +3011,6 @@ log1p_neg_sigmoid = PatternNodeRewriter(
register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus")
register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus")
register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus")
register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
......@@ -3582,12 +3576,40 @@ register_stabilize(local_1msigmoid)
register_specialize(local_1msigmoid)
log1pmexp_to_log1mexp = PatternNodeRewriter(
(log1p, (neg, (exp, "x"))),
(log1mexp, "x"),
allow_multiple_clients=True,
)
register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
@register_stabilize
@node_rewriter([log1p])
def local_log1p_plusminus_exp(fgraph, node):
"""Transforms log1p of ±exp(x) into log1pexp (aka softplus) / log1mexp
``log1p(exp(x)) -> log1pexp(x)``
``log1p(-exp(x)) -> log1mexp(x)``
where "-" can be "neg" or any other expression detected by "is_neg"
"""
(log1p_arg,) = node.inputs
exp_info = is_exp(log1p_arg)
if exp_info is not None:
exp_neg, exp_arg = exp_info
if exp_neg:
return [log1mexp(exp_arg)]
else:
return [log1pexp(exp_arg)] # aka softplus
@register_stabilize
@node_rewriter([expm1])
def logmexpm1_to_log1mexp(fgraph, node):
"""``log(-expm1(x)) -> log1mexp(x)``
where "-" can be "neg" or any other expression detected by "is_neg"
"""
rewrites = {}
for node in get_clients_at_depth(fgraph, node, depth=2):
if node.op == log:
(log_arg,) = node.inputs
neg_arg = is_neg(log_arg)
if neg_arg is not None and neg_arg.owner and neg_arg.owner.op == expm1:
(expm1_arg,) = neg_arg.owner.inputs
rewrites[node.outputs[0]] = log1mexp(expm1_arg)
return rewrites
# log(exp(a) - exp(b)) -> a + log1mexp(b - a)
logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
......
......@@ -4438,11 +4438,22 @@ def test_local_add_neg_to_sub(first_negative):
assert np.allclose(f(x_test, y_test), exp)
def test_log1mexp_stabilization():
@pytest.mark.parametrize(
"op_name",
["log_1_minus_exp", "log1p_minus_exp", "log_minus_expm1", "log_minus_exp_minus_1"],
)
def test_log1mexp_stabilization(op_name):
mode = Mode("py").including("stabilize")
x = vector()
f = function([x], log(1 - exp(x)), mode=mode)
if op_name == "log_1_minus_exp":
f = function([x], log(1 - exp(x)), mode=mode)
elif op_name == "log1p_minus_exp":
f = function([x], log1p(-exp(x)), mode=mode)
elif op_name == "log_minus_expm1":
f = function([x], log(-expm1(x)), mode=mode)
elif op_name == "log_minus_exp_minus_1":
f = function([x], log(-(exp(x) - 1)), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [pt.log1mexp]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论