提交 ac6dc81b authored 作者: ricardoV94's avatar ricardoV94 提交者: Jesse Grabowski

Group local_block_diag_dot_to_dot_block_diag tests

上级 e75bbb2c
...@@ -4857,119 +4857,123 @@ def test_local_dot_to_mul_unspecified_length_1(): ...@@ -4857,119 +4857,123 @@ def test_local_dot_to_mul_unspecified_length_1():
) )
class TestBlockDiagDotToDotBlockDiag:
    """Tests for the ``local_block_diag_dot_to_dot_block_diag`` rewrite.

    The rewrite replaces ``dot(block_diag(x, y, ...), z)`` with a
    concatenation of smaller dots, one per diagonal block:
    ``concat(dot(x, z[:n]), dot(y, z[n:]), ...)``.
    """

    @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
    @pytest.mark.parametrize(
        "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
    )
    @pytest.mark.parametrize(
        "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
    )
    def test_rewrite_applies(self, left_multiply, batch_blockdiag, batch_other):
        """
        Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
        """

        def has_blockdiag(graph):
            # True if any ancestor of `graph` is a BlockDiagonal Op,
            # either as a core Op or wrapped in a Blockwise (batched) Op.
            return any(
                (
                    var.owner
                    and (
                        isinstance(var.owner.op, BlockDiagonal)
                        or (
                            isinstance(var.owner.op, Blockwise)
                            and isinstance(var.owner.op.core_op, BlockDiagonal)
                        )
                    )
                )
                for var in ancestors([graph])
            )

        # Three diagonal blocks with distinct shapes; `b` optionally batched.
        a = tensor("a", shape=(4, 2))
        b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
        c = tensor("c", shape=(4, 4))
        x = pt.linalg.block_diag(a, b, c)
        d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))

        # Test multiple clients are all rewritten
        if left_multiply:
            out = x @ d
        else:
            out = d @ x

        # Before compilation the graph contains BlockDiagonal; after the
        # rewrite it must be gone from the optimized graph.
        assert has_blockdiag(out)
        fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
        assert not has_blockdiag(fn.maker.fgraph.outputs[0])

        # The rewrite should produce one dot per diagonal block (3 total).
        n_dots_rewrite = sum(
            isinstance(node.op, Dot | Dot22)
            or (
                isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, Dot | Dot22)
            )
            for node in fn.maker.fgraph.apply_nodes
        )
        assert n_dots_rewrite == 3

        # Reference function compiled without any rewrites keeps the
        # original single dot on the full block-diagonal matrix.
        fn_expected = pytensor.function(
            [a, b, c, d],
            out,
            mode=Mode(linker="py", optimizer=None),
        )
        assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])

        n_dots_no_rewrite = sum(
            isinstance(node.op, Dot | Dot22)
            or (
                isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, Dot | Dot22)
            )
            for node in fn_expected.maker.fgraph.apply_nodes
        )
        assert n_dots_no_rewrite == 1

        # Numerical equivalence of rewritten vs. reference graphs.
        rng = np.random.default_rng()
        a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
        c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
        d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)

        rewrite_out = fn(a_val, b_val, c_val, d_val)
        expected_out = fn_expected(a_val, b_val, c_val, d_val)
        np.testing.assert_allclose(
            rewrite_out,
            expected_out,
            # Looser tolerances in float32, tight in float64.
            atol=1e-6 if config.floatX == "float32" else 1e-12,
            rtol=1e-6 if config.floatX == "float32" else 1e-12,
        )

    @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
    @pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
    def test_benchmark(self, benchmark, size, rewrite):
        """Benchmark matrix-vector products with and without the rewrite.

        Splits `size` into three random block sizes, builds
        ``block_diag(a, b, c) @ d``, and times the compiled function with
        the rewrite enabled vs. explicitly excluded.
        """
        rng = np.random.default_rng()
        # Random partition of `size` into three positive block sizes.
        a_size = int(rng.uniform(1, int(0.8 * size)))
        b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
        c_size = size - a_size - b_size
        a = tensor("a", shape=(a_size, a_size))
        b = tensor("b", shape=(b_size, b_size))
        c = tensor("c", shape=(c_size, c_size))
        d = tensor("d", shape=(size,))
        x = pt.linalg.block_diag(a, b, c)
        out = x @ d

        mode = get_default_mode()
        if not rewrite:
            # Disable only the rewrite under test; keep everything else.
            mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
        fn = pytensor.function([a, b, c, d], out, mode=mode)

        a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
        c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
        d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
        benchmark(
            fn,
            a_val,
            b_val,
            c_val,
            d_val,
        )
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论