提交 945e9799 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Respond to feedback

上级 2ea17e72
...@@ -70,9 +70,6 @@ MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) ...@@ -70,9 +70,6 @@ MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
def fuse_blockdiagonal(fgraph, node): def fuse_blockdiagonal(fgraph, node):
"""Fuse nested BlockDiagonal ops into a single BlockDiagonal.""" """Fuse nested BlockDiagonal ops into a single BlockDiagonal."""
if not isinstance(node.op, BlockDiagonal):
return None
new_inputs = [] new_inputs = []
changed = False changed = False
...@@ -86,6 +83,7 @@ def fuse_blockdiagonal(fgraph, node): ...@@ -86,6 +83,7 @@ def fuse_blockdiagonal(fgraph, node):
if changed: if changed:
fused_op = BlockDiagonal(len(new_inputs)) fused_op = BlockDiagonal(len(new_inputs))
new_output = fused_op(*new_inputs) new_output = fused_op(*new_inputs)
copy_stack_trace(node.outputs[0], new_output)
return [new_output] return [new_output]
return None return None
......
...@@ -10,7 +10,7 @@ from pytensor import function ...@@ -10,7 +10,7 @@ from pytensor import function
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile import get_default_mode from pytensor.compile import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import ancestors from pytensor.graph import FunctionGraph, ancestors
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import swapaxes from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
...@@ -52,23 +52,22 @@ def test_nested_blockdiag_fusion(): ...@@ -52,23 +52,22 @@ def test_nested_blockdiag_fusion():
inner = BlockDiagonal(2)(x, y) inner = BlockDiagonal(2)(x, y)
outer = BlockDiagonal(2)(inner, z) outer = BlockDiagonal(2)(inner, z)
nodes_before = ancestors([outer])
initial_count = sum( initial_count = sum(
1 1
for node in nodes_before for node in ancestors([outer])
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
) )
assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops" assert initial_count == 2, "Setup failed: expected 2 nested BlockDiagonal ops"
f = pytensor.function([x, y, z], outer) fgraph = FunctionGraph(inputs=[x, y, z], outputs=[outer])
fgraph = f.maker.fgraph rewrite_graph(fgraph, include=("fast_run", "blockdiag_fusion"))
nodes_after = fgraph.apply_nodes fused_nodes = [
fused_nodes = [node for node in nodes_after if isinstance(node.op, BlockDiagonal)] node for node in fgraph.toposort() if isinstance(node.op, BlockDiagonal)
]
assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused" assert len(fused_nodes) == 1, "Nested BlockDiagonal ops were not fused"
fused_op = fused_nodes[0].op fused_op = fused_nodes[0].op
assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}" assert fused_op.n_inputs == 3, f"Expected n_inputs=3, got {fused_op.n_inputs}"
out_shape = fgraph.outputs[0].type.shape out_shape = fgraph.outputs[0].type.shape
...@@ -85,21 +84,20 @@ def test_deeply_nested_blockdiag_fusion(): ...@@ -85,21 +84,20 @@ def test_deeply_nested_blockdiag_fusion():
inner2 = BlockDiagonal(2)(inner1, z) inner2 = BlockDiagonal(2)(inner1, z)
outer = BlockDiagonal(2)(inner2, w) outer = BlockDiagonal(2)(inner2, w)
f = pytensor.function([x, y, z, w], outer) fgraph = FunctionGraph(inputs=[x, y, z, w], outputs=[outer])
fgraph = f.maker.fgraph rewrite_graph(fgraph, include=("fast_run", "blockdiag_fusion"))
fused_nodes = [ fused_block_diag_nodes = [
node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal) node for node in fgraph.apply_nodes if isinstance(node.op, BlockDiagonal)
] ]
assert len(fused_block_diag_nodes) == 1, (
assert len(fused_nodes) == 1, ( f"Expected 1 fused BlockDiagonal, got {len(fused_block_diag_nodes)}"
f"Expected 1 fused BlockDiagonal, got {len(fused_nodes)}"
) )
fused_op = fused_nodes[0].op fused_block_diag_op = fused_block_diag_nodes[0].op
assert fused_op.n_inputs == 4, ( assert fused_block_diag_op.n_inputs == 4, (
f"Expected n_inputs=4 after fusion, got {fused_op.n_inputs}" f"Expected n_inputs=4 after fusion, got {fused_block_diag_op.n_inputs}"
) )
out_shape = fgraph.outputs[0].type.shape out_shape = fgraph.outputs[0].type.shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论