提交 54a85007 authored 作者: ricardoV94's avatar ricardoV94 提交者: Jesse Grabowski

Fix bug in local_block_diag_dot_to_dot_block_diag

上级 ac6dc81b
......@@ -183,13 +183,23 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
try:
client_idx = client.inputs.index(blockdiag_result)
except ValueError:
# If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
# DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
# blockdiag result.
# If the blockdiag result is not an input to the dot, there is at least one Op between them.
# We allow left expand_dims (DimShuffle), which is introduced automatically by Blockwise to equalize the number of batch dims
# but does not change the semantics of the graph.
for ancestor in client.inputs:
if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
if (
ancestor.owner is not None
and (
isinstance(ancestor.owner.op, DimShuffle)
and ancestor.owner.op.is_left_expand_dims
)
and blockdiag_result in ancestor.owner.inputs
):
client_idx = client.inputs.index(ancestor)
break
else: # no-break
# Not a simple left expand_dims between dot and block_diag
return None
other_input = client.inputs[1 - client_idx]
......
......@@ -148,6 +148,7 @@ from pytensor.tensor.type import (
)
from pytensor.tensor.variable import TensorConstant
from tests import unittest_tools as utt
from tests.unittest_tools import assert_equal_computations
rewrite_mode = config.mode
......@@ -4944,6 +4945,29 @@ class TestBlockDiagDotToDotBlockDiag:
rtol=1e-6 if config.floatX == "float32" else 1e-12,
)
def test_rewrite_does_not_apply(self):
    """Check that rewriting leaves these block_diag/dot graphs unchanged.

    Regression test for https://github.com/pymc-devs/pytensor/issues/1836
    """
    # Shapes match if either R or y is transposed, but not by default.
    y = pt.tensor("y", shape=(7, 9))
    R1 = pt.tensor("R1", shape=(2, 3))
    R2 = pt.tensor("R2", shape=(5, 6))
    R = pt.linalg.block_diag(R1, R2)

    # Both cases below run through the same set of rewrite passes.
    rewrite_passes = ("canonicalize", "stabilize", "specialize")

    # A transpose sits between block_diag and dot. This could be rewritten
    # in the future; if that happens, remove this check.
    transposed_case = dot(R.mT, y)
    assert_equal_computations(
        [rewrite_graph(transposed_case, include=rewrite_passes)],
        [transposed_case],
    )

    # An exp sits between block_diag and dot. This is unlikely to ever be
    # rewritten.
    exp_case = dot(R.exp(), y.mT)
    assert_equal_computations(
        [rewrite_graph(exp_case, include=rewrite_passes)],
        [exp_case],
    )
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
def test_benchmark(self, benchmark, size, rewrite):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论