提交 b5313f1e authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Separate local_exp_log

Separate opts that require a `nan_switch` from those that do not, to avoid unnecessarily complicated graphs with nested log(exp(...)) forms
上级 091eec50
......@@ -249,6 +249,30 @@ def local_exp_log(fgraph, node):
if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log):
return x.owner.inputs
# Case for exp(softplus(x)) aka exp(log1pexp)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = add(1, exp(x))
if new_out.type != old_out.type:
return
return [new_out]
@register_specialize
@local_optimizer([Elemwise])
def local_exp_log_nan_switch(fgraph, node):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
x = node.inputs[0]
if not isinstance(node.op, Elemwise):
return
if not x.owner or not isinstance(x.owner.op, Elemwise):
return
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
# Case for exp(log(x))
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
......@@ -276,15 +300,6 @@ def local_exp_log(fgraph, node):
return
return [new_out]
# Case for exp(softplus(x)) aka exp(log1pexp)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = add(1, exp(x))
if new_out.type != old_out.type:
return
return [new_out]
@register_canonicalize
@register_specialize
......
......@@ -2493,7 +2493,10 @@ class TestFuncInverse:
class TestExpLog:
def setup_method(self):
mode = get_default_mode()
self.mode = mode.including("local_exp_log").excluding("fusion")
self.mode = mode.including(
"local_exp_log",
"local_exp_log_nan_switch",
).excluding("fusion")
def test_log_exp(self):
# log(exp(x)) -> x
......@@ -2583,6 +2586,26 @@ class TestExpLog:
decimal=6,
)
@pytest.mark.parametrize(
["nested_expression", "expected_switches"],
[
(lambda x: exp(log(exp(log(exp(x))))), 0),
(lambda x: exp(log(exp(log(x)))), 1),
],
)
def test_exp_log_nested(self, nested_expression, expected_switches):
# Make sure nested exp-log graphs have as little `nan` switches as necessary
x = fvector()
f = function([x], nested_expression(x), mode=self.mode)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Switch)
]
assert len(ops_graph) == expected_switches
class TestLocalSwitchSink:
def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论