提交 2ea17e72 作者: Eby Elanjikal 提交者: Ricardo Vieira

linalg: fuse nested BlockDiagonal ops and add corresponding tests

上级 faa4175b
......@@ -60,24 +60,23 @@ from pytensor.tensor.slinalg import (
solve_triangular,
)
from pytensor.tensor.slinalg import BlockDiagonal
logger = logging.getLogger(__name__)
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.graph import Apply
@register_canonicalize
@node_rewriter([BlockDiagonal])
def fuse_blockdiagonal(fgraph, node):
    """Fuse nested BlockDiagonal ops into a single BlockDiagonal.

    Every input of ``node`` that is itself the output of a ``BlockDiagonal``
    is replaced by that inner op's inputs, so one level of nesting is
    flattened per application.

    Parameters
    ----------
    fgraph
        The function graph being rewritten. Not used by this rewrite.
    node
        The Apply node to inspect.

    Returns
    -------
    list or None
        A single-element list holding the output of the fused
        ``BlockDiagonal`` when at least one input was flattened, otherwise
        ``None`` to signal that nothing was changed.
    """
    if not isinstance(node.op, BlockDiagonal):
        return None

    new_inputs = []
    changed = False

    for inp in node.inputs:
        if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
            # Splice the inner op's blocks in place of its output so the
            # order of the diagonal blocks is preserved.
            new_inputs.extend(inp.owner.inputs)
            changed = True
        else:
            new_inputs.append(inp)

    if not changed:
        return None

    # The op is parametrised by its number of inputs, so a new instance is
    # built for the flattened input list.
    fused_op = BlockDiagonal(len(new_inputs))
    new_output = fused_op(*new_inputs)
    return [new_output]
def is_matrix_transpose(x: TensorVariable) -> bool:
......
......@@ -43,48 +43,70 @@ from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal
def test_nested_blockdiag_fusion():
    """Check that a BlockDiagonal nested in another one is fused on compile."""
    # Static shapes are given so the fused output shape can be asserted below.
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))

    # Nested BlockDiagonal
    inner = BlockDiagonal(2)(x, y)
    outer = BlockDiagonal(2)(inner, z)

    # Count number of BlockDiagonal ops before fusion
    nodes_before = ancestors([outer])
    initial_count = sum(
        1
        for node in nodes_before
        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
    )
    assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"

    # Compile the graph and inspect the resulting function graph rather than
    # calling the rewrite directly.
    f = pytensor.function([x, y, z], outer)
    fgraph = f.maker.fgraph

    # Count number of BlockDiagonal ops after fusion
    nodes_after = fgraph.apply_nodes
    fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)]
    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"

    fused_op = fused_nodes[0].op
    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"

    out_shape = fgraph.outputs[0].type.shape
    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"
def test_deeply_nested_blockdiag_fusion():
    """Check that three levels of nested BlockDiagonal ops fuse into one."""
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))
    w = pt.tensor("w", shape=(3, 3))

    # Each level wraps the previous result, giving three stacked ops.
    inner1 = BlockDiagonal(2)(x, y)
    inner2 = BlockDiagonal(2)(inner1, z)
    outer = BlockDiagonal(2)(inner2, w)

    f = pytensor.function([x, y, z, w], outer)
    fgraph = f.maker.fgraph

    fused_nodes = [
        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
    ]

    assert len(fused_nodes) == 1, (
        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
    )

    fused_op = fused_nodes[0].op
    assert fused_op.n_inputs == 4, (
        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
    )

    out_shape = fgraph.outputs[0].type.shape
    expected_shape = (12, 12)  # 4 blocks of (3x3)
    assert out_shape == expected_shape, (
        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
    )
def test_matrix_inverse_rop_lop():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论