提交 1d13f8c4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid runtime broadcast error due to dot_to_mul rewrite

上级 1a31bb30
......@@ -106,7 +106,7 @@ from pytensor.tensor.rewriting.basic import (
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.shape import Shape, Shape_i, specify_shape
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
......@@ -424,6 +424,13 @@ def local_dot_to_mul(fgraph, node):
):
return None
# Add specify_shape for unknown dimensions that must be 1
# To avoid runtime broadcast error by multiply
if a.type.shape[-1] != 1:
a = specify_shape(a, (..., None, 1))
if b.type.shape[-2] != 1:
b = specify_shape(b, (..., 1, None))
new_out = mul(a, b)
copy_stack_trace(node.out, new_out)
return [new_out]
......
......@@ -4805,6 +4805,35 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
)
def test_local_dot_to_mul_unspecified_length_1():
# Regression test for https://github.com/pymc-devs/pytensor/issues/1782
x = matrix("x", shape=(5, 1), dtype="float64")
y = matrix("y", shape=(None, 1), dtype="float64")
out = x @ y
fn = function([x, y], out)
assert all(
isinstance(node.op, Elemwise | SpecifyShape)
for node in fn.maker.fgraph.apply_nodes
)
np.testing.assert_allclose(
fn(x=np.ones((5, 1)), y=np.ones((1, 1)) * 5),
np.ones((5, 1)) * 5,
)
x = matrix("x", shape=(1, None), dtype="float64")
y = matrix("y", shape=(1, 5), dtype="float64")
out = x @ y
fn = function([x, y], out)
assert all(
isinstance(node.op, Elemwise | SpecifyShape)
for node in fn.maker.fgraph.apply_nodes
)
np.testing.assert_allclose(
fn(x=np.ones((1, 1)) * 5, y=np.ones((1, 5))),
np.ones((1, 5)) * 5,
)
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
@pytest.mark.parametrize(
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论