提交 f7d1bb3b authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Implement `log1mexp` op

Closes #360
上级 16c2c5cf
......@@ -20,6 +20,7 @@ from aesara.scalar.basic import (
exp,
float64,
float_types,
true_div,
upcast,
upgrade_to_float,
upgrade_to_float64,
......@@ -997,3 +998,49 @@ class Softplus(UnaryScalarOp):
softplus = Softplus(upgrade_to_float, name="scalar_softplus")
class Log1mexp(UnaryScalarOp):
r"""
Compute log(1 - exp(x)), also known as log1mexp
This function is numerically more stable than the naive approach.
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
References
----------
.. [Machler2012] Martin Mächler (2012).
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
"""
@staticmethod
def static_impl(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
else:
return np.log(-np.expm1(x))
def impl(self, x):
return Log1mexp.static_impl(x)
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
return [gz * true_div(1.0, 1.0 - exp(-x))]
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(z,) = out
if node.inputs[0].type in float_types:
if node.inputs[0].type == float64:
return f"{z} = {x} < -0.6931471805599453 ? log1p(-exp({x})) : log(-expm1({x}));"
else:
return f"{z} = {x} < -0.6931471805599453f ? log1p(-exp({x})) : log(-expm1({x}));"
else:
raise NotImplementedError("only floating point is implemented")
log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp")
......@@ -318,6 +318,11 @@ def softplus_inplace(x):
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""
@scalar_elemwise
def log1mexp_inplace(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
......
......@@ -1424,6 +1424,11 @@ def softplus(x):
log1pexp = softplus
@scalar_elemwise
def log1mexp(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
......@@ -2903,6 +2908,7 @@ __all__ = [
"expit",
"softplus",
"log1pexp",
"log1mexp",
"real",
"imag",
"angle",
......
......@@ -567,6 +567,39 @@ class TestSoftplus:
np.testing.assert_allclose(y_th, y_np, rtol=10e-10)
_good_broadcast_unary_log1mexp = dict(
normal=(random_ranged(-10.0, 0, (2, 3)),),
float32=(random_ranged(-10.0, 0, (2, 3)).astype("float32"),),
empty=(np.asarray([], dtype=config.floatX),),
int=(integers_ranged(-10, -1, (2, 3)),),
)
_grad_broadcast_unary_log1mexp = dict(
normal=(random_ranged(-10.0, 0.0, (2, 3)),),
)
def expected_log1mexp(x):
return check_floatX(x, np.log(-np.expm1(x)))
TestLog1mexpBroadcast = makeBroadcastTester(
op=aet.log1mexp,
expected=expected_log1mexp,
good=_good_broadcast_unary_log1mexp,
grad=_grad_broadcast_unary_log1mexp,
eps=1e-8,
)
TestLog1mexpInplaceBroadcast = makeBroadcastTester(
op=inplace.log1mexp_inplace,
expected=expected_log1mexp,
good=_good_broadcast_unary_log1mexp,
eps=1e-8,
inplace=True,
)
def test_deprecated_module():
with pytest.warns(DeprecationWarning):
import aesara.scalar.basic_scipy # noqa: F401
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论