提交 c22e79e1 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix bug in einsum

A shortcut in the numpy implementation of einsum_path, taken when there is nothing to optimize, creates a default path that can combine more than 2 operands in a single step. Our implementation only supports steps with 1 or 2 operands. https://github.com/numpy/numpy/blob/cc5851e654bfd82a23f2758be4bd224be84fc1c3/numpy/_core/einsumfunc.py#L945-L951
上级 8bb2038d
......@@ -410,6 +410,12 @@ def _contraction_list_from_path(
return contraction_list
def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
# Create a right to left contraction path
# if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
return tuple(pairwise(reversed(range(n))))
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
"""
Multiplication and summation of tensors using the Einstein summation convention.
......@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
else:
# By default, we try right to left because we assume that most graphs
# have a lower dimensional rightmost operand
path = tuple(pairwise(reversed(range(len(tensor_operands)))))
path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
......@@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
einsum_call=True, # Not part of public API
optimize="optimal",
) # type: ignore
path = tuple(contraction[0] for contraction in contraction_list)
np_path = tuple(contraction[0] for contraction in contraction_list)
if len(np_path) == 1 and len(np_path[0]) > 2:
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
# pairwise reductions, which our implementation below demands.
path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
else:
path = np_path
optimized = True
def removechars(s, chars):
......@@ -744,7 +761,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
)
else:
raise ValueError(
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}, {path=}."
)
# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
......
......@@ -262,3 +262,22 @@ def test_broadcastable_dims():
atol = 1e-12 if config.floatX == "float64" else 1e-2
np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
@pytest.mark.parametrize("static_length", [False, True])
def test_threeway_mul(static_length):
    # Regression test for https://github.com/pymc-devs/pytensor/issues/1184:
    # einsum used to fail when numpy's einsum_path collapsed all three
    # operands into a single >2-operand contraction step.
    shape = (3,) if static_length else (None,)
    x, y, z = (tensor(name, shape=shape) for name in ("x", "y", "z"))
    result = einsum("..., ..., ... -> ...", x, y, z)

    base = np.ones((3,), dtype=x.dtype)
    np.testing.assert_allclose(
        result.eval({x: base, y: base + 1, z: base + 2}),
        np.full((3,), fill_value=6),
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论