Commit 5b85bca4 authored by Ricardo, committed by Ricardo Vieira

Safeguard local_log_add_exp against -inf and extend it to more than 2 inputs

Fixes #461
Parent b84ac43a
......@@ -74,6 +74,7 @@ from aesara.tensor.math import (
expm1,
ge,
int_div,
isinf,
log,
log1p,
makeKeepDims,
......@@ -2286,31 +2287,28 @@ def local_log1p(fgraph, node):
@register_stabilize
@register_specialize
@local_optimizer([log])
def local_log_add_exp(fgraph, node):
    """Rewrite ``log(exp(x) + exp(y) + ...)`` into a numerically stable form.

    Applies the log-sum-exp trick for any number of terms::

        log(sum_i exp(x_i)) == max + log(sum_i exp(x_i - max))

    where ``max = maximum(*x_i)``.  Subtracting the maximum before
    exponentiating prevents overflow when the ``x_i`` are large.

    The offset is skipped (via ``switch``) when ``max`` is ``+-inf`` so the
    output does not become ``nan`` (fixes #461): e.g. with all inputs
    ``-inf``, ``exp(p - max_pre)`` would be ``exp(nan)``.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here, required by the
        ``local_optimizer`` interface).
    node
        The ``Apply`` node to match; only ``log`` nodes whose input is an
        ``add`` of pure ``exp`` terms are rewritten.

    Returns
    -------
    list or None
        A one-element list with the stabilized replacement variable, or
        ``None`` when the pattern does not match.
    """
    if node.op == log:
        z = node.inputs[0]
        if z.owner and z.owner.op == add:
            zi = z.owner.inputs
            pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
            # Only rewrite when *every* argument of the add is exp(<something>)
            if len(pre_exp) == len(zi):
                # Do not offset when max_pre = -np.inf, to avoid nan in the output
                # The switch statement is placed directly inside add to break the
                # self-symmetry of the returned output (otherwise the rewrite
                # would match its own output and never stabilize)
                max_pre = reduce(maximum, pre_exp)
                ret = max_pre + log(
                    add(
                        *[
                            switch(isinf(max_pre), exp(max_pre), exp(p - max_pre))
                            for p in pre_exp
                        ]
                    )
                )
                return [ret]
......
......@@ -1840,10 +1840,7 @@ def test_log1p():
assert [node.op for node in f.maker.fgraph.toposort()] == [log1p]
@pytest.mark.xfail(
reason="log(add(exp)) is not stabilized when adding more than 2 elements, see #623"
)
def test_log_add():
def test_local_log_add_exp():
m = config.mode
if m == "FAST_COMPILE":
m = "FAST_RUN"
......@@ -1858,26 +1855,28 @@ def test_log_add():
y = dvector()
f = function([x, y], log(exp(x) + exp(y)), mode=m)
f([10000], [10000]) # causes overflow if handled incorrectly
assert np.isfinite(f([10000], [10000]))
# test that it gives the correct result when it doesn't overflow
f([10], [10]) # doesn't causes overflow
utt.assert_allclose(f([10], [10]), 10 + np.log1p(1))
assert np.isfinite(f([10000], [10000])) # causes overflow if handled incorrectly
utt.assert_allclose(f([10000], [10000]), 10000 + np.log1p(1))
# test that it give the same result when it don't overflow
f([10], [10]) # don't causes overflow
utt.assert_allclose(f([10], [10]), 10 + np.log1p(1))
# test that when max = +-inf, optimized output still works correctly
assert f([-np.inf], [-np.inf]) == -np.inf
assert f([np.inf], [np.inf]) == np.inf
assert f([np.inf], [-np.inf]) == np.inf
# test that it also works with more than two args, (this currently fails)
# test that it also works with more than two args
x = dvector()
y = dvector()
f = function([x, y], log(exp(x) + exp(y) + exp(x - y) + exp(x + y)), mode=m)
f([10000], [10000]) # causes overflow if handled incorrectly
assert np.isfinite(f([10000], [10000])) # causes overflow if handled incorrectly
utt.assert_allclose(f([10000], [10000]), 20000)
# TODO: test that the optimization works in the presence of broadcasting.
# TODO: (write and) test that the optimization works with Sum in addition to working with Add.
def test_local_subtensor_of_dot():
m1 = matrix()
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment