提交 f3a4f2b6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Seed logsumexp benchmark tests

Also adds missing numba benchmark test Co-authored-by: 's avatarBrandon T. Willard <brandonwillard@users.noreply.github.com>
上级 b8831aa7
...@@ -111,7 +111,8 @@ def test_logsumexp_benchmark(size, axis, benchmark): ...@@ -111,7 +111,8 @@ def test_logsumexp_benchmark(size, axis, benchmark):
X_max = at.switch(at.isinf(X_max), 0, X_max) X_max = at.switch(at.isinf(X_max), 0, X_max)
X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max
X_val = np.random.normal(size=size) rng = np.random.default_rng(23920)
X_val = rng.normal(size=size)
X_lse_fn = pytensor.function([X], X_lse, mode="JAX") X_lse_fn = pytensor.function([X], X_lse, mode="JAX")
......
...@@ -2,7 +2,9 @@ import contextlib ...@@ -2,7 +2,9 @@ import contextlib
import numpy as np import numpy as np
import pytest import pytest
import scipy.special
import pytensor
import pytensor.tensor as at import pytensor.tensor as at
import pytensor.tensor.inplace as ati import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem import pytensor.tensor.math as aem
...@@ -532,3 +534,24 @@ def test_MaxAndArgmax(x, axes, exc): ...@@ -532,3 +534,24 @@ def test_MaxAndArgmax(x, axes, exc):
if not isinstance(i, (SharedVariable, Constant)) if not isinstance(i, (SharedVariable, Constant))
], ],
) )
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
X = at.matrix("X")
X_max = at.max(X, axis=axis, keepdims=True)
X_max = at.switch(at.isinf(X_max), 0, X_max)
X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max
rng = np.random.default_rng(23920)
X_val = rng.normal(size=size)
X_lse_fn = pytensor.function([X], X_lse, mode="JAX")
# JIT compile first
_ = X_lse_fn(X_val)
res = benchmark(X_lse_fn, X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论