提交 d960e510 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Add rewrites to simplify unnecessary `logit` - `sigmoid` expressions

上级 15bd3a40
...@@ -3528,3 +3528,29 @@ log1pmexp_to_log1mexp = PatternSub( ...@@ -3528,3 +3528,29 @@ log1pmexp_to_log1mexp = PatternSub(
allow_multiple_clients=True, allow_multiple_clients=True,
) )
register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp") register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
# Rewrite: logit(sigmoid(x)) -> x
# Spelled out, the matched graph is log(sigmoid(x) / (1 - sigmoid(x))),
# where both occurrences of "x" must bind to the same variable.
# NOTE(review): `tracks=[sigmoid]` combined with
# `get_nodes=get_clients_at_depth2` presumably starts from each sigmoid node
# and tries the pattern on its clients two levels up
# (sigmoid -> true_div -> log) -- confirm against the helper's definition.
local_logit_sigmoid = PatternSub(
    (
        log,
        (
            true_div,
            (sigmoid, "x"),
            (sub, 1, (sigmoid, "x")),
        ),
    ),
    "x",
    name="local_logit_sigmoid",
    allow_multiple_clients=True,
    tracks=[sigmoid],
    get_nodes=get_clients_at_depth2,
)
# The same rewrite object is registered for both the canonicalize and the
# specialize passes.
register_canonicalize(local_logit_sigmoid)
register_specialize(local_logit_sigmoid)
# Rewrite: sigmoid(logit(x)) -> x
# Spelled out, the matched graph is sigmoid(log(x / (1 - x))), where both
# occurrences of "x" must bind to the same variable.
# NOTE(review): the replacement returns x unconditionally, whereas the
# original expression is only real-valued for x in the unit interval --
# confirm that dropping the implicit NaN for out-of-range x is acceptable.
local_sigmoid_logit = PatternSub(
    (
        sigmoid,
        (
            log,
            (true_div, "x", (sub, 1, "x")),
        ),
    ),
    "x",
    name="local_sigmoid_logit",
    allow_multiple_clients=True,
)
# The same rewrite object is registered for both the canonicalize and the
# specialize passes.
register_canonicalize(local_sigmoid_logit)
register_specialize(local_sigmoid_logit)
...@@ -4507,3 +4507,25 @@ def test_log1mexp_stabilization(): ...@@ -4507,3 +4507,25 @@ def test_log1mexp_stabilization():
f(np.array([-0.8, -0.6], dtype=config.floatX)), f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])), np.log(1 - np.exp([-0.8, -0.6])),
) )
def test_local_logit_sigmoid():
    """Check that composing sigmoid and logit in either order rewrites to `x`.

    Builds sigmoid(logit(x)) and then logit(sigmoid(x)), runs each graph
    through `optimize`, and asserts that no apply nodes remain and that the
    graph output is the very same variable as the graph input.
    """

    def logit_fn(x):
        # logit written out explicitly as log(x / (1 - x)).
        return log(x / (1 - x))

    x = fmatrix()

    # Each composition is built lazily, immediately before it is optimized,
    # so the build/optimize order is: first graph, then second graph.
    compositions = (
        lambda v: sigmoid(logit_fn(v)),
        lambda v: logit_fn(sigmoid(v)),
    )
    for compose in compositions:
        out = compose(x)
        fg = optimize(FunctionGraph([x], [out]))
        # Nothing left to compute: the output is the input itself.
        assert list(fg.toposort()) == []
        assert fg.inputs[0] is fg.outputs[0]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论