提交 945e9799 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Respond to feedback

上级 2ea17e72
......@@ -70,9 +70,6 @@ MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
def fuse_blockdiagonal(fgraph, node):
"""Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
if not isinstance(node.op, BlockDiagonal):
return None
new_inputs = []
changed = False
......@@ -86,6 +83,7 @@ def fuse_blockdiagonal(fgraph, node):
if changed:
fused_op = BlockDiagonal(len(new_inputs))
new_output = fused_op(*new_inputs)
copy_stack_trace(node.outputs[0], new_output)
return [new_output]
return None
......
......@@ -10,7 +10,7 @@ from pytensor import function
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph import FunctionGraph, ancestors
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
......@@ -52,23 +52,22 @@ def test_nested_blockdiag_fusion():
inner = BlockDiagonal(2)(x, y)
outer = BlockDiagonal(2)(inner, z)
nodes_before = ancestors([outer])
initial_count = sum(
1
for node in nodes_before
for node in ancestors([outer])
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"
f = pytensor.function([x, y, z], outer)
fgraph = f.maker.fgraph
fgraph = FunctionGraph(inputs=[x, y, z], outputs=[outer])
rewrite_graph(fgraph, include=("fast_run", "blockdiag_fusion"))
nodes_after = fgraph.apply_nodes
fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)]
fused_nodes = [
node for node in fgraph.toposort() if isinstance(node.op, BlockDiagonal)
]
assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"
fused_op = fused_nodes[0].op
assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"
out_shape = fgraph.outputs[0].type.shape
......@@ -85,21 +84,20 @@ def test_deeply_nested_blockdiag_fusion():
inner2 = BlockDiagonal(2)(inner1, z)
outer = BlockDiagonal(2)(inner2, w)
f = pytensor.function([x, y, z, w], outer)
fgraph = f.maker.fgraph
fgraph = FunctionGraph(inputs=[x, y, z, w], outputs=[outer])
rewrite_graph(fgraph, include=("fast_run", "blockdiag_fusion"))
fused_nodes = [
fused_block_diag_nodes = [
node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
]
assert len(fused_nodes) == 1, (
f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
assert len(fused_block_diag_nodes) == 1, (
f"Expected 1 fused BlockDiagonal, got {len(fused_block_diag_nodes)}"
)
fused_op = fused_nodes[0].op
fused_block_diag_op = fused_block_diag_nodes[0].op
assert fused_op.n_inputs == 4, (
f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}"
assert fused_block_diag_op.n_inputs == 4, (
f"Expected n_inputs=4 after fusion, got {fused_block_diag_op.n_inputs}"
)
out_shape = fgraph.outputs[0].type.shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论