提交 e476ebcd authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Fix limit of `log1mexp` gradient at zero and improve numerical precision

上级 cfa867b3
...@@ -20,10 +20,13 @@ from aesara.scalar.basic import ( ...@@ -20,10 +20,13 @@ from aesara.scalar.basic import (
complex_types, complex_types,
discrete_types, discrete_types,
exp, exp,
expm1,
float64, float64,
float_types, float_types,
isinf,
log, log,
log1p, log1p,
switch,
true_div, true_div,
upcast, upcast,
upgrade_to_float, upgrade_to_float,
...@@ -1201,7 +1204,10 @@ class Log1mexp(UnaryScalarOp): ...@@ -1201,7 +1204,10 @@ class Log1mexp(UnaryScalarOp):
def grad(self, inp, grads):
    """Symbolic gradient of log1mexp(x) = log(1 - exp(x)).

    The analytic derivative is -1 / expm1(-x). Using ``expm1`` (rather
    than ``1 - exp(-x)``) keeps precision for inputs near zero.
    """
    (x,) = inp
    (gz,) = grads
    grad_x = true_div(-1.0, expm1(-x))
    # The division overflows to +/-inf as x -> 0-; the correct one-sided
    # limit of the gradient at 0.0 is -inf, so clamp it there explicitly.
    grad_x = switch(isinf(grad_x), -np.inf, grad_x)
    return [gz * grad_x]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
(x,) = inp (x,) = inp
......
...@@ -74,6 +74,7 @@ from aesara.tensor.math import ( ...@@ -74,6 +74,7 @@ from aesara.tensor.math import (
isnan, isnan,
isnan_, isnan_,
log, log,
log1mexp,
log1p, log1p,
log2, log2,
log10, log10,
...@@ -3343,3 +3344,13 @@ def test_pprint(): ...@@ -3343,3 +3344,13 @@ def test_pprint():
x = vector("x") x = vector("x")
y = aet_sum(x, axis=0) y = aet_sum(x, axis=0)
assert pprint(y) == "sum(x, axis=(0,))" assert pprint(y) == "sum(x, axis=(0,))"
def test_log1mexp_grad_lim():
    """The gradient of log1mexp must be exactly -inf at (and just below) zero."""
    x = dscalar("x")
    [grad_x] = grad(log1mexp(x), [x])
    grad_fn = function([x], grad_x)
    # Zero, negative zero, and a subnormal input (which underflows in the
    # gradient expression) must all saturate to -inf.
    for val in (0.0, -0.0, -1e-309):
        assert grad_fn(val) == -np.inf
    # The smallest normal double is representable: the gradient stays finite.
    assert grad_fn(-1e-308) != -np.inf
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论