提交 c22e79e1 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix bug in einsum

A shortcut in the numpy implementation of einsum_path, taken when there is nothing to optimize, creates a default path that can combine more than 2 operands in a single step. Our implementation only supports steps with 1 or 2 operands. https://github.com/numpy/numpy/blob/cc5851e654bfd82a23f2758be4bd224be84fc1c3/numpy/_core/einsumfunc.py#L945-L951
上级 8bb2038d
...@@ -410,6 +410,12 @@ def _contraction_list_from_path( ...@@ -410,6 +410,12 @@ def _contraction_list_from_path(
return contraction_list return contraction_list
def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
# Create a right to left contraction path
# if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
return tuple(pairwise(reversed(range(n))))
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable: def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
""" """
Multiplication and summation of tensors using the Einstein summation convention. Multiplication and summation of tensors using the Einstein summation convention.
...@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar ...@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
else: else:
# By default, we try right to left because we assume that most graphs # By default, we try right to left because we assume that most graphs
# have a lower dimensional rightmost operand # have a lower dimensional rightmost operand
path = tuple(pairwise(reversed(range(len(tensor_operands))))) path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path( contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path subscripts, tensor_operands, path
) )
...@@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar ...@@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
einsum_call=True, # Not part of public API einsum_call=True, # Not part of public API
optimize="optimal", optimize="optimal",
) # type: ignore ) # type: ignore
path = tuple(contraction[0] for contraction in contraction_list) np_path = tuple(contraction[0] for contraction in contraction_list)
if len(np_path) == 1 and len(np_path[0]) > 2:
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
# pairwise reductions, which our implementation below demands.
path = _right_to_left_path(len(tensor_operands))
contraction_list = _contraction_list_from_path(
subscripts, tensor_operands, path
)
else:
path = np_path
optimized = True optimized = True
def removechars(s, chars): def removechars(s, chars):
...@@ -744,7 +761,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar ...@@ -744,7 +761,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
) )
else: else:
raise ValueError( raise ValueError(
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}" f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}, {path=}."
) )
# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
......
...@@ -262,3 +262,22 @@ def test_broadcastable_dims(): ...@@ -262,3 +262,22 @@ def test_broadcastable_dims():
atol = 1e-12 if config.floatX == "float64" else 1e-2 atol = 1e-12 if config.floatX == "float64" else 1e-2
np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol) np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
np.testing.assert_allclose(optimal_eval, np_eval, atol=atol) np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
@pytest.mark.parametrize("static_length", [False, True])
def test_threeway_mul(static_length):
    # Regression test for https://github.com/pymc-devs/pytensor/issues/1184:
    # with nothing to optimize, numpy's einsum_path emits a single step that
    # combines all three operands at once, which the pairwise einsum
    # implementation must rewrite into 2-operand steps.
    sh = (3,) if static_length else (None,)
    x = tensor("x", shape=sh)
    y = tensor("y", shape=sh)
    z = tensor("z", shape=sh)
    out = einsum("..., ..., ... -> ...", x, y, z)

    x_test = np.ones((3,), dtype=x.dtype)
    y_test = x_test + 1
    z_test = x_test + 2
    # elementwise product: 1 * 2 * 3 == 6 at every position
    np.testing.assert_allclose(
        out.eval({x: x_test, y: y_test, z: z_test}),
        np.full((3,), fill_value=6),
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论