提交 842bc52a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add benchmark tests for fused Elemwises

上级 2f94d1a8
......@@ -11,6 +11,7 @@ import pytensor.tensor.math as aem
from pytensor import config, function
from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as at_elemwise
......@@ -555,3 +556,18 @@ def test_logsumexp_benchmark(size, axis, benchmark):
res = benchmark(X_lse_fn, X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
def test_fused_elemwise_benchmark(benchmark):
rng = np.random.default_rng(123)
size = 100_000
x = pytensor.shared(rng.normal(size=size), name="x")
mu = pytensor.shared(rng.normal(size=size), name="mu")
logp = -((x - mu) ** 2) / 2
grad_logp = grad(logp.sum(), x)
func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
# JIT compile first
func()
benchmark(func)
......@@ -9,6 +9,7 @@ from pytensor import tensor as at
from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
......@@ -1349,6 +1350,18 @@ class TestFusion:
assert len(nodes) == 1
assert isinstance(nodes[0].op.scalar_op, Composite)
def test_eval_benchmark(self, benchmark):
rng = np.random.default_rng(123)
size = 100_000
x = pytensor.shared(rng.normal(size=size), name="x")
mu = pytensor.shared(rng.normal(size=size), name="mu")
logp = -((x - mu) ** 2) / 2
grad_logp = grad(logp.sum(), x)
func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
benchmark(func)
class TimesN(aes.basic.UnaryScalarOp):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论