Commit 3a6cdb3a authored by Ricardo, committed by Ricardo Vieira

Add `local_1msigmoid` rewrite by default

Parent e53c49ab
......@@ -3012,7 +3012,6 @@ logsigm_to_softplus = PatternSub(
tracks=[sigmoid],
get_nodes=get_clients_at_depth1,
)
log1msigm_to_softplus = PatternSub(
(log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))),
(neg, (softplus, "x")),
......@@ -3613,38 +3612,16 @@ def local_reciprocal_1_plus_exp(fgraph, node):
return out
# Registration is below, and conditional.
@local_optimizer([sub])
def local_1msigmoid(fgraph, node):
    """Rewrite ``1 - sigm(x)`` into ``sigm(-x)``.

    The rewrite only fires when the ``sigmoid`` output has a single
    client, so a graph that uses both ``sigm(x)`` and ``1 - sigm(x)``
    keeps sharing the original ``sigmoid`` node.
    """
    if node.op != sub:
        return
    minuend, subtrahend = node.inputs
    # Skip when the sigmoid result also feeds other consumers.
    if len(fgraph.clients[subtrahend]) > 1:
        return
    if subtrahend.owner is None or subtrahend.owner.op != sigmoid:
        return
    try:
        minuend_val = get_scalar_constant_value(minuend)
    except NotScalarConstantError:
        # The left operand is not a scalar constant; nothing to do.
        return
    if np.allclose(np.sum(minuend_val), 1):
        replacement = sigmoid(-subtrahend.owner.inputs[0])
        copy_stack_trace([subtrahend, node.outputs[0]], replacement)
        return [replacement]
# Gate for registering `local_1msigmoid` as a canonicalization.
register_local_1msigmoid = False
# Kept False because the Stabilize pattern above already looks for the
# `1 - sigm(x)` shape, and AlgebraicCanonizer rewrites `neg` into `*(-1)`,
# so enabling this rewrite could trigger an unwanted chain of follow-up
# rewrites. On the other hand, this transformation can be seen as pushing
# ordinary arithmetic above or below the sigmoidal nonlinearity, so if the
# canonical form ever takes a stance on that, this choice should be
# revisited. Leaving it disabled for now.
if register_local_1msigmoid:
    register_canonicalize(local_1msigmoid)
# 1 - sigmoid(x) -> sigmoid(-x)
local_1msigmoid = PatternSub(
    # Match `sub(y, sigmoid(x))` where `y` is constrained by `_is_1`
    # (presumably a literal one — verify against `_is_1`'s definition).
    (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
    # Replace with `sigmoid(neg(x))`.
    (sigmoid, (neg, "x")),
    # Only consider graphs containing a `sigmoid` node, reached through
    # its depth-1 clients.
    tracks=[sigmoid],
    get_nodes=get_clients_at_depth1,
    name="local_1msigmoid",
)
register_stabilize(local_1msigmoid)
register_specialize(local_1msigmoid)
log1pmexp_to_log1mexp = PatternSub(
......
......@@ -83,7 +83,6 @@ from aesara.tensor.math_opt import (
mul_canonizer,
parse_mul_tree,
perform_sigm_times_exp,
register_local_1msigmoid,
simplify_mul,
)
from aesara.tensor.shape import Reshape, Shape_i
......@@ -4210,28 +4209,27 @@ class TestSigmoidOpts:
# Restore config option.
config.warn__identify_1pexp_bug = backup
def test_1msigmoid(self):
    # NOTE(review): this appears to be the removed side of a diff hunk —
    # the body is truncated right after fetching the mode, and `m` is
    # never used below. Confirm against the full pre-change file.
    if not register_local_1msigmoid:
        return
    m = self.get_mode()
def test_local_1msigmoid(self):
    """Check that `local_1msigmoid` rewrites `1 - sigm(x)` to `sigm(-x)`."""
    # Exclude fusion/inplace so the toposort can be compared op-by-op.
    m = self.get_mode(excluding=["fusion", "inplace"])
    x = fmatrix()
    # tests exp_over_1_plus_exp
    f = aesara.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
    # NOTE(review): the next two assertions expect `sigmoid_inplace`
    # even though the mode excludes "inplace" — they look like the
    # removed side of a diff hunk interleaved with the added lines
    # below; verify against the actual post-change file.
    assert check_stack_trace(f, ops_to_check=[neg, inplace.sigmoid_inplace])
    assert [node.op for node in f.maker.fgraph.toposort()] == [
        neg,
        inplace.sigmoid_inplace,
    ]
    # FIXME: PatternSub does not copy stack trace
    # (see https://github.com/Theano/Theano/issues/4581)
    # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
    assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
    # tests inv_1_plus_exp
    f = aesara.function([x], 1 - aet.fill(x, 1.0) / (1 + exp(-x)), mode=m)
    # NOTE(review): same interleaved old/new assertion pattern as above.
    assert check_stack_trace(f, ops_to_check=[neg, inplace.sigmoid_inplace])
    assert [node.op for node in f.maker.fgraph.toposort()] == [
        neg,
        inplace.sigmoid_inplace,
    ]
    # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
    assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
    # Tests float constant
    f = aesara.function(
        [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
    )
    assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
def test_local_sigm_times_exp(self):
# Test the `local_sigm_times_exp` optimization.
......
Markdown format
0%
You are mentioning 0 people in this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment