提交 08f49d16 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Safeguard local_log_sum_exp optimization against -inf values and make it non-symmetric

Fixes #461
上级 5b85bca4
...@@ -11,7 +11,6 @@ from functools import reduce ...@@ -11,7 +11,6 @@ from functools import reduce
import numpy as np import numpy as np
import aesara.scalar.basic as aes import aesara.scalar.basic as aes
from aesara import compile
from aesara.assert_op import assert_op from aesara.assert_op import assert_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
...@@ -2312,6 +2311,8 @@ def local_log_add_exp(fgraph, node): ...@@ -2312,6 +2311,8 @@ def local_log_add_exp(fgraph, node):
return [ret] return [ret]
@register_stabilize
@register_specialize
@local_optimizer([log]) @local_optimizer([log])
def local_log_sum_exp(fgraph, node): def local_log_sum_exp(fgraph, node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))) # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
...@@ -2340,7 +2341,19 @@ def local_log_sum_exp(fgraph, node): ...@@ -2340,7 +2341,19 @@ def local_log_sum_exp(fgraph, node):
max_pre_exp = aet_max(pre_exp, axis=axis) max_pre_exp = aet_max(pre_exp, axis=axis)
max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis) max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis)
ret = max_pre_exp + log(aet_sum(exp(pre_exp - max_pre_exp_keepdims), axis=axis)) # Do not offset when max_pre = -np.inf, to avoid nan in the output
# Switch statement is placed directly inside sum to break the self-symmetry
# of the returned output (otherwise the optimization would not stabilize)
ret = max_pre_exp + log(
aet_sum(
switch(
isinf(max_pre_exp_keepdims),
exp(max_pre_exp_keepdims),
exp(pre_exp - max_pre_exp_keepdims),
),
axis=axis,
),
)
# Restore the dimshuffle op, if any. # Restore the dimshuffle op, if any.
if dimshuffle_op: if dimshuffle_op:
...@@ -2349,14 +2362,6 @@ def local_log_sum_exp(fgraph, node): ...@@ -2349,14 +2362,6 @@ def local_log_sum_exp(fgraph, node):
return [ret] return [ret]
compile.optdb.register(
"local_log_sum_exp",
in2out(local_log_sum_exp, ignore_newtrees=True),
1.6,
"fast_run",
)
def add_calculate(num, denum, aslist=False, out_type=None): def add_calculate(num, denum, aslist=False, out_type=None):
# TODO: make sure that this function and mul_calculate are similar # TODO: make sure that this function and mul_calculate are similar
if out_type is None: if out_type is None:
......
...@@ -4000,6 +4000,16 @@ def test_local_log_sum_exp3(): ...@@ -4000,6 +4000,16 @@ def test_local_log_sum_exp3():
assert np.allclose(optimised_ret, 100.0) assert np.allclose(optimised_ret, 100.0)
def test_local_log_sum_exp_inf():
# Test that when max = +-inf, optimized output still works correctly
x = vector("x")
f = compile_graph_log_sum_exp(x, axis=0)
assert f([-np.inf, -np.inf]) == -np.inf
assert f([np.inf, np.inf]) == np.inf
assert f([-np.inf, np.inf]) == np.inf
def test_local_reciprocal_1_plus_exp(): def test_local_reciprocal_1_plus_exp():
x = vector("x") x = vector("x")
y = aet.reciprocal(1 + exp(x)) y = aet.reciprocal(1 + exp(x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论