Commit e476ebcd authored by Ricardo, committed by Brandon T. Willard

Fix limit of `log1mexp` gradient at zero and improve numerical precision

Parent: cfa867b3
......@@ -20,10 +20,13 @@ from aesara.scalar.basic import (
complex_types,
discrete_types,
exp,
expm1,
float64,
float_types,
isinf,
log,
log1p,
switch,
true_div,
upcast,
upgrade_to_float,
......@@ -1201,7 +1204,10 @@ class Log1mexp(UnaryScalarOp):
def grad(self, inp, grads):
    """Gradient of ``log1mexp(x) = log(1 - exp(x))`` with respect to ``x``.

    The analytic derivative is ``-1 / expm1(-x)`` (equivalently
    ``1 / (1 - exp(-x))``).  The ``expm1`` form is used because it keeps
    full floating-point precision for ``x`` near zero, where
    ``1 - exp(-x)`` suffers catastrophic cancellation.
    """
    (x,) = inp
    (gz,) = grads
    # Numerically stable form of 1 / (1 - exp(-x)).
    res = true_div(-1.0, expm1(-x))
    # expm1(-0.0) == 0.0, so the division overflows to +/-inf at x == 0;
    # the correct one-sided limit of the gradient there is -inf.
    res = switch(isinf(res), -np.inf, res)
    return [gz * res]
def c_code(self, node, name, inp, out, sub):
(x,) = inp
......
......@@ -74,6 +74,7 @@ from aesara.tensor.math import (
isnan,
isnan_,
log,
log1mexp,
log1p,
log2,
log10,
......@@ -3343,3 +3344,13 @@ def test_pprint():
x = vector("x")
y = aet_sum(x, axis=0)
assert pprint(y) == "sum(x, axis=(0,))"
def test_log1mexp_grad_lim():
    """The gradient of ``log1mexp`` must diverge to -inf as x -> 0-."""
    x = dscalar("x")
    (dy_dx,) = grad(log1mexp(x), [x])
    eval_grad = function([x], dy_dx)
    # At zero (either signed zero) and for a subnormal just below zero,
    # the gradient is the exact limit, -inf.
    for at_limit in (0.0, -0.0, -1e-309):
        assert eval_grad(at_limit) == -np.inf
    # Slightly further from zero the gradient is huge but still finite.
    assert eval_grad(-1e-308) != -np.inf
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment