提交 fcd46689 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add `logaddexp`

Closes #467
上级 ecd6a1e9
...@@ -2764,6 +2764,25 @@ def power(x, y): ...@@ -2764,6 +2764,25 @@ def power(x, y):
return x ** y return x ** y
def logaddexp(*xs):
"""Logarithm of the sum of exponentiations of the inputs.
See ``numpy.logaddexp``.
Parameters
----------
xs : symbolic tensors
Input
Returns
-------
tensor
"""
return log(add(*[exp(x) for x in xs]))
def logsumexp(x, axis=None, keepdims=False): def logsumexp(x, axis=None, keepdims=False):
"""Compute the log of the sum of exponentials of input elements. """Compute the log of the sum of exponentials of input elements.
...@@ -2913,5 +2932,6 @@ __all__ = [ ...@@ -2913,5 +2932,6 @@ __all__ = [
"all", "all",
"ptp", "ptp",
"power", "power",
"logaddexp",
"logsumexp", "logsumexp",
] ]
...@@ -76,6 +76,7 @@ from aesara.tensor.math import ( ...@@ -76,6 +76,7 @@ from aesara.tensor.math import (
log1p, log1p,
log2, log2,
log10, log10,
logaddexp,
logsumexp, logsumexp,
max, max,
max_and_argmax, max_and_argmax,
...@@ -125,6 +126,7 @@ from aesara.tensor.type import ( ...@@ -125,6 +126,7 @@ from aesara.tensor.type import (
matrices, matrices,
matrix, matrix,
scalar, scalar,
scalars,
tensor, tensor,
tensor3, tensor3,
tensor4, tensor4,
...@@ -3277,6 +3279,44 @@ def test_tanh_grad_broadcast(): ...@@ -3277,6 +3279,44 @@ def test_tanh_grad_broadcast():
grad(tanh(x + y).sum(), [x, y]) grad(tanh(x + y).sum(), [x, y])
def test_logaddexp():
# Test more than two multidimensional inputs
x, y, z = matrices("x", "y", "z")
out = logaddexp(x, y, z)
f = function([x, y, z], out)
inp = np.zeros((3, 3), dtype=config.floatX)
np.testing.assert_allclose(
f(inp, inp, inp),
np.full((3, 3), np.log(3)),
)
# Test scalar inputs
x, y = scalars("x", "y")
out = logaddexp(x, y)
f = function([x, y], out)
res = f(0, 0)
assert np.ndim(res) == 0
assert np.isclose(res, np.log(2))
# Test scalar and matrix inputs
x = scalar("x")
y = matrix("y")
out = logaddexp(x, y)
f = function([x, y], out)
res = f(
np.array(0, dtype=config.floatX),
np.zeros((3, 3), dtype=config.floatX),
)
assert np.shape(res) == (3, 3)
np.testing.assert_allclose(
res,
np.full((3, 3), np.log(2)),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
["shape", "axis"], ["shape", "axis"],
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论