Unverified 提交 00fea0e3 authored 作者: Abhinav's avatar Abhinav 提交者: GitHub

Fix einsum failing with repeated inputs (#1260)

* fixed Einsum failing with repeated inputs * Optimise the _ensure_not_equal function Co-authored-by: 's avatarRicardo Vieira <28983449+ricardoV94@users.noreply.github.com> * Fix einsum failing on repeated inputs * Fix einsum failing with repeated inputs * Added regression test for repeated inputs to the einsum function * Fix for failing test Co-authored-by: 's avatarRicardo Vieira <28983449+ricardoV94@users.noreply.github.com> --------- Co-authored-by: 's avatarRicardo Vieira <28983449+ricardoV94@users.noreply.github.com>
上级 c0860f86
......@@ -417,6 +417,18 @@ def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
return tuple(pairwise(reversed(range(n))))
def _ensure_not_equal(elements):
"""
Ensures that any pair in a list of elements are not the same object. If a pair of elements is found to be equal, then one of them is converted to a copy.
"""
elements = list(elements)
for i, elem1 in enumerate(elements[:-1]):
for j, elem2 in enumerate(elements[i + 1 :], start=i + 1):
if elem1 is elem2:
elements[j] = elem1.copy()
return elements
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
"""
Multiplication and summation of tensors using the Einstein summation convention.
......@@ -553,7 +565,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
"If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
)
tensor_operands = [as_tensor(operand) for operand in operands]
tensor_operands = _ensure_not_equal([as_tensor(operand) for operand in operands])
shapes = [operand.type.shape for operand in tensor_operands]
path: PATH
......
......@@ -8,6 +8,7 @@ import pytensor
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor import matrix
from pytensor.tensor.basic import moveaxis
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
......@@ -281,3 +282,15 @@ def test_threeway_mul(static_length):
out.eval({x: x_test, y: y_test, z: z_test}),
np.full((3,), fill_value=6),
)
def test_repeated_inputs():
x = matrix("x")
out_repeated = einsum("ij,ij->i", x, x)
out_copy = einsum("ij,ij->i", x, x.copy())
x_test = np.array([[1, 2], [3, 4]]).astype(x.dtype)
np.testing.assert_allclose(
out_repeated.eval({x: x_test}), out_copy.eval({x: x_test})
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论