提交 2ea17e72 authored 作者: Eby Elanjikal's avatar Eby Elanjikal 提交者: Ricardo Vieira

linalg: fuse nested BlockDiagonal ops and add corresponding tests

上级 faa4175b
...@@ -60,24 +60,23 @@ from pytensor.tensor.slinalg import ( ...@@ -60,24 +60,23 @@ from pytensor.tensor.slinalg import (
solve_triangular, solve_triangular,
) )
from pytensor.tensor.slinalg import BlockDiagonal
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
@register_canonicalize
@node_rewriter([BlockDiagonal])
def fuse_blockdiagonal(fgraph, node):
    """Fuse nested BlockDiagonal ops into a single BlockDiagonal.

    Rewrites ``BlockDiagonal(BlockDiagonal(a, b), c)`` into
    ``BlockDiagonal(a, b, c)``.  Node rewriters are applied repeatedly until
    a fixed point, so arbitrarily deep nestings collapse into one op.

    Parameters
    ----------
    fgraph
        Function graph being rewritten (required by the ``node_rewriter``
        signature; not used directly here).
    node
        Apply node whose op is a ``BlockDiagonal`` — guaranteed by the
        ``@node_rewriter([BlockDiagonal])`` tracker, so no extra
        ``isinstance`` guard is needed.

    Returns
    -------
    list | None
        A one-element list containing the fused output variable, or ``None``
        when no input is itself a BlockDiagonal (rewrite does not apply).
    """
    new_inputs = []
    changed = False

    for inp in node.inputs:
        # Splice a nested BlockDiagonal's own inputs directly into ours.
        if inp.owner is not None and isinstance(inp.owner.op, BlockDiagonal):
            new_inputs.extend(inp.owner.inputs)
            changed = True
        else:
            new_inputs.append(inp)

    if not changed:
        return None

    # BlockDiagonal is parameterized by its input count, so a fused op
    # must be rebuilt with the flattened input list.
    fused_op = BlockDiagonal(len(new_inputs))
    return [fused_op(*new_inputs)]
def is_matrix_transpose(x: TensorVariable) -> bool: def is_matrix_transpose(x: TensorVariable) -> bool:
......
...@@ -43,50 +43,72 @@ from pytensor.tensor.type import dmatrix, matrix, tensor, vector ...@@ -43,50 +43,72 @@ from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal
def test_nested_blockdiag_fusion():
    """One level of BlockDiagonal nesting fuses into a single 3-input op."""
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))

    inner = BlockDiagonal(2)(x, y)
    outer = BlockDiagonal(2)(inner, z)

    # Sanity-check the un-rewritten graph really contains two nested ops.
    nodes_before = ancestors([outer])
    initial_count = sum(
        1
        for node in nodes_before
        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
    )
    assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"

    # Compiling the function runs the canonicalization rewrites,
    # including fuse_blockdiagonal.
    f = pytensor.function([x, y, z], outer)
    fgraph = f.maker.fgraph

    fused_nodes = [
        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
    ]
    assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"

    fused_op = fused_nodes[0].op
    assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"

    # Three (3, 3) blocks on the diagonal yield a (9, 9) output.
    out_shape = fgraph.outputs[0].type.shape
    assert out_shape == (9, 9), f"Unexpected fused output shape: {out_shape}"
def test_deeply_nested_blockdiag_fusion():
    """Three levels of BlockDiagonal nesting fuse into one 4-input op."""
    x = pt.tensor("x", shape=(3, 3))
    y = pt.tensor("y", shape=(3, 3))
    z = pt.tensor("z", shape=(3, 3))
    w = pt.tensor("w", shape=(3, 3))

    # Chain of nested BlockDiagonal ops: ((x, y), z), w.
    inner1 = BlockDiagonal(2)(x, y)
    inner2 = BlockDiagonal(2)(inner1, z)
    outer = BlockDiagonal(2)(inner2, w)

    # Compilation applies the rewrite until fixed point, so even a
    # multi-level nesting should collapse into a single node.
    f = pytensor.function([x, y, z, w], outer)
    fgraph = f.maker.fgraph

    fused_nodes = [
        node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
    ]
    assert len(fused_nodes) == 1, (
        f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
    )

    fused_op = fused_nodes[0].op
    assert fused_op.n_inputs == 4, (
        f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
    )

    out_shape = fgraph.outputs[0].type.shape
    expected_shape = (12, 12)  # 4 blocks of (3, 3)
    assert out_shape == expected_shape, (
        f"Unexpected fused output shape: expected {expected_shape}, got {out_shape}"
    )
def test_matrix_inverse_rop_lop(): def test_matrix_inverse_rop_lop():
rtol = 1e-7 if config.floatX == "float64" else 1e-5 rtol = 1e-7 if config.floatX == "float64" else 1e-5
mx = matrix("mx") mx = matrix("mx")
......
Markdown 格式
0%
您正在将 0 人添加到此讨论，请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论