Commit ac6dc81b authored by ricardoV94, committed by Jesse Grabowski

Group local_block_diag_dot_to_dot_block_diag tests

Parent e75bbb2c
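Grouping the correctness test and the benchmark under one class also means both can now be selected with a single pytest expression, e.g. pytest -k TestBlockDiagDotToDotBlockDiag.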
@@ -4857,119 +4857,123 @@ def test_local_dot_to_mul_unspecified_length_1():
 )
-@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
-@pytest.mark.parametrize(
-    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
-)
-@pytest.mark.parametrize(
-    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
-)
-def test_local_block_diag_dot_to_dot_block_diag(
-    left_multiply, batch_blockdiag, batch_other
-):
-    """
-    Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
-    """
+class TestBlockDiagDotToDotBlockDiag:
+    @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+    @pytest.mark.parametrize(
+        "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+    )
+    @pytest.mark.parametrize(
+        "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+    )
+    def test_rewrite_applies(self, left_multiply, batch_blockdiag, batch_other):
+        """
+        Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+        """
 
-    def has_blockdiag(graph):
-        return any(
-            (
-                var.owner
-                and (
-                    isinstance(var.owner.op, BlockDiagonal)
-                    or (
-                        isinstance(var.owner.op, Blockwise)
-                        and isinstance(var.owner.op.core_op, BlockDiagonal)
-                    )
-                )
-            )
-            for var in ancestors([graph])
-        )
+        def has_blockdiag(graph):
+            return any(
+                (
+                    var.owner
+                    and (
+                        isinstance(var.owner.op, BlockDiagonal)
+                        or (
+                            isinstance(var.owner.op, Blockwise)
+                            and isinstance(var.owner.op.core_op, BlockDiagonal)
+                        )
+                    )
+                )
+                for var in ancestors([graph])
+            )
 
-    a = tensor("a", shape=(4, 2))
-    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
-    c = tensor("c", shape=(4, 4))
-    x = pt.linalg.block_diag(a, b, c)
+        a = tensor("a", shape=(4, 2))
+        b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+        c = tensor("c", shape=(4, 4))
+        x = pt.linalg.block_diag(a, b, c)
 
-    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+        d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
 
-    # Test multiple clients are all rewritten
-    if left_multiply:
-        out = x @ d
-    else:
-        out = d @ x
+        # Test multiple clients are all rewritten
+        if left_multiply:
+            out = x @ d
+        else:
+            out = d @ x
 
-    assert has_blockdiag(out)
-    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
-    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+        assert has_blockdiag(out)
+        fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+        assert not has_blockdiag(fn.maker.fgraph.outputs[0])
 
-    n_dots_rewrite = sum(
-        isinstance(node.op, Dot | Dot22)
-        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
-        for node in fn.maker.fgraph.apply_nodes
-    )
-    assert n_dots_rewrite == 3
+        n_dots_rewrite = sum(
+            isinstance(node.op, Dot | Dot22)
+            or (
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, Dot | Dot22)
+            )
+            for node in fn.maker.fgraph.apply_nodes
+        )
+        assert n_dots_rewrite == 3
 
-    fn_expected = pytensor.function(
-        [a, b, c, d],
-        out,
-        mode=Mode(linker="py", optimizer=None),
-    )
-    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+        fn_expected = pytensor.function(
+            [a, b, c, d],
+            out,
+            mode=Mode(linker="py", optimizer=None),
+        )
+        assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
 
-    n_dots_no_rewrite = sum(
-        isinstance(node.op, Dot | Dot22)
-        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
-        for node in fn_expected.maker.fgraph.apply_nodes
-    )
-    assert n_dots_no_rewrite == 1
+        n_dots_no_rewrite = sum(
+            isinstance(node.op, Dot | Dot22)
+            or (
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, Dot | Dot22)
+            )
+            for node in fn_expected.maker.fgraph.apply_nodes
+        )
+        assert n_dots_no_rewrite == 1
 
-    rng = np.random.default_rng()
-    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
-    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
-    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
-    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+        rng = np.random.default_rng()
+        a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+        c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+        d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
 
-    rewrite_out = fn(a_val, b_val, c_val, d_val)
-    expected_out = fn_expected(a_val, b_val, c_val, d_val)
-    np.testing.assert_allclose(
-        rewrite_out,
-        expected_out,
-        atol=1e-6 if config.floatX == "float32" else 1e-12,
-        rtol=1e-6 if config.floatX == "float32" else 1e-12,
-    )
+        rewrite_out = fn(a_val, b_val, c_val, d_val)
+        expected_out = fn_expected(a_val, b_val, c_val, d_val)
+        np.testing.assert_allclose(
+            rewrite_out,
+            expected_out,
+            atol=1e-6 if config.floatX == "float32" else 1e-12,
+            rtol=1e-6 if config.floatX == "float32" else 1e-12,
+        )
 
-
-@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
-@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
-def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
-    rng = np.random.default_rng()
-    a_size = int(rng.uniform(1, int(0.8 * size)))
-    b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
-    c_size = size - a_size - b_size
+    @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+    @pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+    def test_benchmark(self, benchmark, size, rewrite):
+        rng = np.random.default_rng()
+        a_size = int(rng.uniform(1, int(0.8 * size)))
+        b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
+        c_size = size - a_size - b_size
 
-    a = tensor("a", shape=(a_size, a_size))
-    b = tensor("b", shape=(b_size, b_size))
-    c = tensor("c", shape=(c_size, c_size))
-    d = tensor("d", shape=(size,))
+        a = tensor("a", shape=(a_size, a_size))
+        b = tensor("b", shape=(b_size, b_size))
+        c = tensor("c", shape=(c_size, c_size))
+        d = tensor("d", shape=(size,))
 
-    x = pt.linalg.block_diag(a, b, c)
-    out = x @ d
+        x = pt.linalg.block_diag(a, b, c)
+        out = x @ d
 
-    mode = get_default_mode()
-    if not rewrite:
-        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
-    fn = pytensor.function([a, b, c, d], out, mode=mode)
+        mode = get_default_mode()
+        if not rewrite:
+            mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+        fn = pytensor.function([a, b, c, d], out, mode=mode)
 
-    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
-    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
-    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
-    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+        a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+        c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+        d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
 
-    benchmark(
-        fn,
-        a_val,
-        b_val,
-        c_val,
-        d_val,
-    )
+        benchmark(
+            fn,
+            a_val,
+            b_val,
+            c_val,
+            d_val,
+        )
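For context, the identity these tests exercise, as stated in the docstring of test_rewrite_applies: multiplying by a block-diagonal matrix decomposes into independent per-block products, one dot per block, which is why the rewritten graph is expected to contain exactly three dots. A minimal NumPy/SciPy sketch of that identity, reusing the core shapes from the test; this is an illustration, not part of the commit:

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 2))
b = rng.normal(size=(2, 4))
c = rng.normal(size=(4, 4))
x = block_diag(a, b, c)  # (10, 10): row blocks 4+2+4, column blocks 2+4+4
d = rng.normal(size=(10, 10))

# Left multiplication: each block only touches the rows of d that line up
# with its own column span, so one big dot splits into one dot per block.
left = np.concatenate([a @ d[:2], b @ d[2:6], c @ d[6:]], axis=0)
np.testing.assert_allclose(x @ d, left)

# Right multiplication: slice d's columns by each block's row span instead.
right = np.concatenate([d[:, :4] @ a, d[:, 4:6] @ b, d[:, 6:] @ c], axis=1)
np.testing.assert_allclose(d @ x, right)

This decomposition is also what test_benchmark measures: with the rewrite excluded, the dense (size, size) block-diagonal matrix is materialized and multiplied whole, while the rewrite replaces that with the three much smaller per-block dots.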