提交 86cc3b87 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Benchmark special vector case in GEMV

上级 a7b46524
...@@ -413,6 +413,32 @@ class TestBlasStridesC(TestBlasStrides): ...@@ -413,6 +413,32 @@ class TestBlasStridesC(TestBlasStrides):
mode = mode_blas_opt mode = mode_blas_opt
def test_gemv_vector_dot_perf(benchmark):
n = 400_000
a = pt.vector("A", shape=(n,))
b = pt.vector("x", shape=(n,))
out = CGemv(inplace=True)(
pt.empty((1,)),
1.0,
a[None],
b,
0.0,
)
fn = pytensor.function([a, b], out, accept_inplace=True, trust_input=True)
rng = np.random.default_rng(430)
test_a = rng.normal(size=n)
test_b = rng.normal(size=n)
np.testing.assert_allclose(
fn(test_a, test_b),
np.dot(test_a, test_b),
)
benchmark(fn, test_a, test_b)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"] "neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"]
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论