提交 faa4175b authored 作者: Eby Elanjikal's avatar Eby Elanjikal 提交者: Ricardo Vieira

Add fuse_blockdiagonal rewrite and corresponding test for nested BlockDiagonal

上级 ac11da62
......@@ -60,11 +60,36 @@ from pytensor.tensor.slinalg import (
solve_triangular,
)
from pytensor.tensor.slinalg import BlockDiagonal
logger = logging.getLogger(__name__)
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.graph import Apply
def fuse_blockdiagonal(node):
# Only process if this node is a BlockDiagonal
if not isinstance(node.owner.op, BlockDiagonal):
return node
new_inputs = []
changed = False
for inp in node.owner.inputs:
# If input is itself a BlockDiagonal, flatten its inputs
if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
new_inputs.extend(inp.owner.inputs)
changed = True
else:
new_inputs.append(inp)
if changed:
# Return a new fused BlockDiagonal with all inputs
return BlockDiagonal(len(new_inputs))(*new_inputs)
return node
def is_matrix_transpose(x: TensorVariable) -> bool:
"""Check if a variable corresponds to a transpose of the last two axes"""
node = x.owner
......
......@@ -43,7 +43,50 @@ from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal
def test_nested_blockdiag_fusion():
# Create matrix variables
x = pt.matrix("x")
y = pt.matrix("y")
z = pt.matrix("z")
# Nested BlockDiagonal
inner = BlockDiagonal(2)(x, y)
outer = BlockDiagonal(2)(inner, z)
# Count number of BlockDiagonal ops before fusion
nodes_before = ancestors([outer])
initial_count = sum(
1 for node in nodes_before
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
# Apply the rewrite
fused = fuse_blockdiagonal(outer)
# Count number of BlockDiagonal ops after fusion
nodes_after = ancestors([fused])
fused_count = sum(
1 for node in nodes_after
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
# Check that all original inputs are preserved
fused_inputs = [
inp
for node in ancestors([fused])
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
for inp in node.owner.inputs
]
assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused"
def test_matrix_inverse_rop_lop():
rtol = 1e-7 if config.floatX == "float64" else 1e-5
mx = matrix("mx")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论