提交 3a6cdb3a authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Add `local_1msigmoid` rewrite by default

上级 e53c49ab
...@@ -3012,7 +3012,6 @@ logsigm_to_softplus = PatternSub( ...@@ -3012,7 +3012,6 @@ logsigm_to_softplus = PatternSub(
tracks=[sigmoid], tracks=[sigmoid],
get_nodes=get_clients_at_depth1, get_nodes=get_clients_at_depth1,
) )
log1msigm_to_softplus = PatternSub( log1msigm_to_softplus = PatternSub(
(log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))), (log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))),
(neg, (softplus, "x")), (neg, (softplus, "x")),
...@@ -3613,38 +3612,16 @@ def local_reciprocal_1_plus_exp(fgraph, node): ...@@ -3613,38 +3612,16 @@ def local_reciprocal_1_plus_exp(fgraph, node):
return out return out
# Registration is below, and conditional. # 1 - sigmoid(x) -> sigmoid(-x)
@local_optimizer([sub]) local_1msigmoid = PatternSub(
def local_1msigmoid(fgraph, node): (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
""" (sigmoid, (neg, "x")),
1-sigm(x) -> sigm(-x) tracks=[sigmoid],
get_nodes=get_clients_at_depth1,
""" name="local_1msigmoid",
if node.op == sub: )
sub_l, sub_r = node.inputs register_stabilize(local_1msigmoid)
if len(fgraph.clients[sub_r]) > 1: register_specialize(local_1msigmoid)
return # graph is using both sigm and 1-sigm
if sub_r.owner and sub_r.owner.op == sigmoid:
try:
val_l = get_scalar_constant_value(sub_l)
except NotScalarConstantError:
return
if np.allclose(np.sum(val_l), 1):
out = sigmoid(-sub_r.owner.inputs[0])
copy_stack_trace([sub_r, node.outputs[0]], out)
return [out]
register_local_1msigmoid = False
# This is False because the Stabilize pattern above
# is looking for 1-sigm. Also AlgebraicCanonizer turns neg into *(-1) and so
# this optimization might set off an unwanted chain of things.
# OTH - this transformation can be seen as pushing normal arithmetic either below or above the
# sigmoidal nonlinearity... so if the canonicalized form had anything to say about that then it
# would be a consideration... anyway leaving False for now.
if register_local_1msigmoid:
register_canonicalize(local_1msigmoid)
log1pmexp_to_log1mexp = PatternSub( log1pmexp_to_log1mexp = PatternSub(
......
...@@ -83,7 +83,6 @@ from aesara.tensor.math_opt import ( ...@@ -83,7 +83,6 @@ from aesara.tensor.math_opt import (
mul_canonizer, mul_canonizer,
parse_mul_tree, parse_mul_tree,
perform_sigm_times_exp, perform_sigm_times_exp,
register_local_1msigmoid,
simplify_mul, simplify_mul,
) )
from aesara.tensor.shape import Reshape, Shape_i from aesara.tensor.shape import Reshape, Shape_i
...@@ -4210,28 +4209,27 @@ class TestSigmoidOpts: ...@@ -4210,28 +4209,27 @@ class TestSigmoidOpts:
# Restore config option. # Restore config option.
config.warn__identify_1pexp_bug = backup config.warn__identify_1pexp_bug = backup
def test_1msigmoid(self): def test_local_1msigmoid(self):
if not register_local_1msigmoid: m = self.get_mode(excluding=["fusion", "inplace"])
return
m = self.get_mode()
x = fmatrix() x = fmatrix()
# tests exp_over_1_plus_exp # tests exp_over_1_plus_exp
f = aesara.function([x], 1 - exp(x) / (1 + exp(x)), mode=m) f = aesara.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
assert check_stack_trace(f, ops_to_check=[neg, inplace.sigmoid_inplace]) # FIXME: PatternSub does not copy stack trace
assert [node.op for node in f.maker.fgraph.toposort()] == [ # (see https://github.com/Theano/Theano/issues/4581)
neg, # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
inplace.sigmoid_inplace, assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
]
# tests inv_1_plus_exp # tests inv_1_plus_exp
f = aesara.function([x], 1 - aet.fill(x, 1.0) / (1 + exp(-x)), mode=m) f = aesara.function([x], 1 - aet.fill(x, 1.0) / (1 + exp(-x)), mode=m)
assert check_stack_trace(f, ops_to_check=[neg, inplace.sigmoid_inplace]) # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
assert [node.op for node in f.maker.fgraph.toposort()] == [ assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
neg,
inplace.sigmoid_inplace, # Tests float constant
] f = aesara.function(
[x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
)
assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
def test_local_sigm_times_exp(self): def test_local_sigm_times_exp(self):
# Test the `local_sigm_times_exp` optimization. # Test the `local_sigm_times_exp` optimization.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论