提交 842bc52a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add benchmark tests for fused Elemwises

上级 2f94d1a8
...@@ -11,6 +11,7 @@ import pytensor.tensor.math as aem ...@@ -11,6 +11,7 @@ import pytensor.tensor.math as aem
from pytensor import config, function from pytensor import config, function
from pytensor.compile.ops import deep_copy_op from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as at_elemwise from pytensor.tensor import elemwise as at_elemwise
...@@ -555,3 +556,18 @@ def test_logsumexp_benchmark(size, axis, benchmark): ...@@ -555,3 +556,18 @@ def test_logsumexp_benchmark(size, axis, benchmark):
res = benchmark(X_lse_fn, X_val) res = benchmark(X_lse_fn, X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res) np.testing.assert_array_almost_equal(res, exp_res)
def test_fused_elemwise_benchmark(benchmark):
rng = np.random.default_rng(123)
size = 100_000
x = pytensor.shared(rng.normal(size=size), name="x")
mu = pytensor.shared(rng.normal(size=size), name="mu")
logp = -((x - mu) ** 2) / 2
grad_logp = grad(logp.sum(), x)
func = pytensor.function([], [logp, grad_logp], mode="NUMBA")
# JIT compile first
func()
benchmark(func)
...@@ -9,6 +9,7 @@ from pytensor import tensor as at ...@@ -9,6 +9,7 @@ from pytensor import tensor as at
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.basic import check_stack_trace, out2in
...@@ -1349,6 +1350,18 @@ class TestFusion: ...@@ -1349,6 +1350,18 @@ class TestFusion:
assert len(nodes) == 1 assert len(nodes) == 1
assert isinstance(nodes[0].op.scalar_op, Composite) assert isinstance(nodes[0].op.scalar_op, Composite)
def test_eval_benchmark(self, benchmark):
rng = np.random.default_rng(123)
size = 100_000
x = pytensor.shared(rng.normal(size=size), name="x")
mu = pytensor.shared(rng.normal(size=size), name="mu")
logp = -((x - mu) ** 2) / 2
grad_logp = grad(logp.sum(), x)
func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
benchmark(func)
class TimesN(aes.basic.UnaryScalarOp): class TimesN(aes.basic.UnaryScalarOp):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论