提交 f2721a91 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add benchmark for numba elemwise

上级 b59a0c80
......@@ -6,7 +6,7 @@ import pytest
import pytensor.tensor as at
import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem
from pytensor import config
from pytensor import config, function
from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
......@@ -117,6 +117,25 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals)
def test_elemwise_speed(benchmark):
x = at.dmatrix("y")
y = at.dvector("z")
out = np.exp(2 * x * y + y)
rng = np.random.default_rng(42)
x_val = rng.normal(size=(200, 500))
y_val = rng.normal(size=500)
func = function([x, y], out, mode="NUMBA")
func = func.vm.jit_fn
(out,) = func(x_val, y_val)
np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)
benchmark(func, x_val, y_val)
@pytest.mark.parametrize(
"v, new_order",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论