提交 f2721a91 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add benchmark for numba elemwise

上级 b59a0c80
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import pytensor.tensor as at import pytensor.tensor as at
import pytensor.tensor.inplace as ati import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem import pytensor.tensor.math as aem
from pytensor import config from pytensor import config, function
from pytensor.compile.ops import deep_copy_op from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
...@@ -117,6 +117,25 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): ...@@ -117,6 +117,25 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals) compare_numba_and_py(out_fg, input_vals)
def test_elemwise_speed(benchmark):
x = at.dmatrix("y")
y = at.dvector("z")
out = np.exp(2 * x * y + y)
rng = np.random.default_rng(42)
x_val = rng.normal(size=(200, 500))
y_val = rng.normal(size=500)
func = function([x, y], out, mode="NUMBA")
func = func.vm.jit_fn
(out,) = func(x_val, y_val)
np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)
benchmark(func, x_val, y_val)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, new_order", "v, new_order",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论