提交 ac6dc81b authored 作者: ricardoV94's avatar ricardoV94 提交者: Jesse Grabowski

Group local_block_diag_dot_to_dot_block_diag tests

上级 e75bbb2c
...@@ -4857,16 +4857,15 @@ def test_local_dot_to_mul_unspecified_length_1(): ...@@ -4857,16 +4857,15 @@ def test_local_dot_to_mul_unspecified_length_1():
) )
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) class TestBlockDiagDotToDotBlockDiag:
@pytest.mark.parametrize( @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
@pytest.mark.parametrize(
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"] "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_other", [True, False], ids=["batched_other", "unbatched_other"] "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
) )
def test_local_block_diag_dot_to_dot_block_diag( def test_rewrite_applies(self, left_multiply, batch_blockdiag, batch_other):
left_multiply, batch_blockdiag, batch_other
):
""" """
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])) Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
""" """
...@@ -4905,7 +4904,10 @@ def test_local_block_diag_dot_to_dot_block_diag( ...@@ -4905,7 +4904,10 @@ def test_local_block_diag_dot_to_dot_block_diag(
n_dots_rewrite = sum( n_dots_rewrite = sum(
isinstance(node.op, Dot | Dot22) isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22)) or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Dot | Dot22)
)
for node in fn.maker.fgraph.apply_nodes for node in fn.maker.fgraph.apply_nodes
) )
assert n_dots_rewrite == 3 assert n_dots_rewrite == 3
...@@ -4919,7 +4921,10 @@ def test_local_block_diag_dot_to_dot_block_diag( ...@@ -4919,7 +4921,10 @@ def test_local_block_diag_dot_to_dot_block_diag(
n_dots_no_rewrite = sum( n_dots_no_rewrite = sum(
isinstance(node.op, Dot | Dot22) isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22)) or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Dot | Dot22)
)
for node in fn_expected.maker.fgraph.apply_nodes for node in fn_expected.maker.fgraph.apply_nodes
) )
assert n_dots_no_rewrite == 1 assert n_dots_no_rewrite == 1
...@@ -4939,10 +4944,9 @@ def test_local_block_diag_dot_to_dot_block_diag( ...@@ -4939,10 +4944,9 @@ def test_local_block_diag_dot_to_dot_block_diag(
rtol=1e-6 if config.floatX == "float32" else 1e-12, rtol=1e-6 if config.floatX == "float32" else 1e-12,
) )
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"]) @pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"]) def test_benchmark(self, benchmark, size, rewrite):
def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
rng = np.random.default_rng() rng = np.random.default_rng()
a_size = int(rng.uniform(1, int(0.8 * size))) a_size = int(rng.uniform(1, int(0.8 * size)))
b_size = int(rng.uniform(1, int(0.8 * (size - a_size)))) b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论