提交 f3a4f2b6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Seed logsumexp benchmark tests

Also adds missing numba benchmark test Co-authored-by: 's avatarBrandon T. Willard <brandonwillard@users.noreply.github.com>
上级 b8831aa7
......@@ -111,7 +111,8 @@ def test_logsumexp_benchmark(size, axis, benchmark):
X_max = at.switch(at.isinf(X_max), 0, X_max)
X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max
X_val = np.random.normal(size=size)
rng = np.random.default_rng(23920)
X_val = rng.normal(size=size)
X_lse_fn = pytensor.function([X], X_lse, mode="JAX")
......
......@@ -2,7 +2,9 @@ import contextlib
import numpy as np
import pytest
import scipy.special
import pytensor
import pytensor.tensor as at
import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem
......@@ -532,3 +534,24 @@ def test_MaxAndArgmax(x, axes, exc):
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
X = at.matrix("X")
X_max = at.max(X, axis=axis, keepdims=True)
X_max = at.switch(at.isinf(X_max), 0, X_max)
X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max
rng = np.random.default_rng(23920)
X_val = rng.normal(size=size)
X_lse_fn = pytensor.function([X], X_lse, mode="JAX")
# JIT compile first
_ = X_lse_fn(X_val)
res = benchmark(X_lse_fn, X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论