提交 00a8a883 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add benchmark test for CAReduce

上级 a8303a0d
...@@ -985,3 +985,29 @@ class TestVectorize: ...@@ -985,3 +985,29 @@ class TestVectorize:
assert isinstance(vect_node.op, Any) assert isinstance(vect_node.op, Any)
assert vect_node.op.axis == (1,) assert vect_node.op.axis == (1,)
assert vect_node.inputs[0] is bool_tns assert vect_node.inputs[0] is bool_tns
@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_careduce_benchmark(axis, c_contiguous, benchmark):
N = 256
x_test = np.random.uniform(size=(N, N, N))
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
x = pytensor.shared(x_test, name="x", shape=x_test.shape)
out = x.transpose(transpose_axis).sum(axis=axis)
fn = pytensor.function([], out)
np.testing.assert_allclose(
fn(),
x_test.transpose(transpose_axis).sum(axis=axis),
)
benchmark(fn)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论