Commit 5b85bca4, authored by Ricardo, committed by Ricardo Vieira

Safeguard local_log_add_exp against -inf and extend it to more than 2 inputs

Fixes #461
Parent commit: b84ac43a
...@@ -74,6 +74,7 @@ from aesara.tensor.math import ( ...@@ -74,6 +74,7 @@ from aesara.tensor.math import (
expm1, expm1,
ge, ge,
int_div, int_div,
isinf,
log, log,
log1p, log1p,
makeKeepDims, makeKeepDims,
...@@ -2286,31 +2287,28 @@ def local_log1p(fgraph, node): ...@@ -2286,31 +2287,28 @@ def local_log1p(fgraph, node):
@register_stabilize
@register_specialize
@local_optimizer([log])
def local_log_add_exp(fgraph, node):
    """Rewrite ``log(exp(x) + exp(y) + ...)`` into a numerically stable form.

    Uses the identity
    ``log(sum_i exp(x_i)) == max + log(sum_i exp(x_i - max))``
    with ``max = maximum(x_1, ..., x_n)``, so the expression no longer
    overflows for large inputs.  Works for any number of ``exp`` terms,
    not just two.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here, required by the
        ``local_optimizer`` interface).
    node
        The ``Apply`` node to inspect; the rewrite fires only when it is a
        ``log`` whose single input is an ``add`` of ``exp`` terms.

    Returns
    -------
    list or None
        A one-element list with the stabilized replacement output, or
        ``None`` (implicitly) when the pattern does not match.
    """
    if node.op == log:
        z = node.inputs[0]
        if z.owner and z.owner.op == add:
            zi = z.owner.inputs
            pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
            # Rewrite only when *every* argument of the add is exp(<something>)
            if len(pre_exp) == len(zi):
                # Do not offset when max_pre == -inf, to avoid nan in the output.
                # The switch is placed directly inside the add to break the
                # self-symmetry of the returned output (otherwise the
                # rewrite would not stabilize).
                max_pre = reduce(maximum, pre_exp)
                ret = max_pre + log(
                    add(
                        *[
                            switch(isinf(max_pre), exp(max_pre), exp(p - max_pre))
                            for p in pre_exp
                        ]
                    )
                )
                return [ret]
......
...@@ -1840,10 +1840,7 @@ def test_log1p(): ...@@ -1840,10 +1840,7 @@ def test_log1p():
assert [node.op for node in f.maker.fgraph.toposort()] == [log1p] assert [node.op for node in f.maker.fgraph.toposort()] == [log1p]
@pytest.mark.xfail( def test_local_log_add_exp():
reason="log(add(exp)) is not stabilized when adding more than 2 elements, see #623"
)
def test_log_add():
m = config.mode m = config.mode
if m == "FAST_COMPILE": if m == "FAST_COMPILE":
m = "FAST_RUN" m = "FAST_RUN"
...@@ -1858,26 +1855,28 @@ def test_log_add(): ...@@ -1858,26 +1855,28 @@ def test_log_add():
y = dvector() y = dvector()
f = function([x, y], log(exp(x) + exp(y)), mode=m) f = function([x, y], log(exp(x) + exp(y)), mode=m)
f([10000], [10000]) # causes overflow if handled incorrectly # test that it gives the correct result when it doesn't overflow
assert np.isfinite(f([10000], [10000])) f([10], [10]) # doesn't causes overflow
utt.assert_allclose(f([10], [10]), 10 + np.log1p(1))
assert np.isfinite(f([10000], [10000])) # causes overflow if handled incorrectly
utt.assert_allclose(f([10000], [10000]), 10000 + np.log1p(1)) utt.assert_allclose(f([10000], [10000]), 10000 + np.log1p(1))
# test that it give the same result when it don't overflow # test that when max = +-inf, optimized output still works correctly
f([10], [10]) # don't causes overflow assert f([-np.inf], [-np.inf]) == -np.inf
utt.assert_allclose(f([10], [10]), 10 + np.log1p(1)) assert f([np.inf], [np.inf]) == np.inf
assert f([np.inf], [-np.inf]) == np.inf
# test that it also works with more than two args, (this currently fails) # test that it also works with more than two args
x = dvector() x = dvector()
y = dvector() y = dvector()
f = function([x, y], log(exp(x) + exp(y) + exp(x - y) + exp(x + y)), mode=m) f = function([x, y], log(exp(x) + exp(y) + exp(x - y) + exp(x + y)), mode=m)
f([10000], [10000]) # causes overflow if handled incorrectly assert np.isfinite(f([10000], [10000])) # causes overflow if handled incorrectly
utt.assert_allclose(f([10000], [10000]), 20000) utt.assert_allclose(f([10000], [10000]), 20000)
# TODO: test that the optimization works in the presence of broadcasting. # TODO: test that the optimization works in the presence of broadcasting.
# TODO: (write and) test that the optimization works with Sum in addition to working with Add.
def test_local_subtensor_of_dot(): def test_local_subtensor_of_dot():
m1 = matrix() m1 = matrix()
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment