Commit 4765cdad authored by ricardoV94, committed by Ricardo Vieira

Einsum: Compatibility with newer numpy

Parent d95fa25a
@@ -32,6 +32,7 @@ from pytensor.tensor.variable import TensorVariable
 PATH = tuple[tuple[int] | tuple[int, int], ...]
+CONTRACTION_STEP = tuple[tuple[int, ...], set[str], str]
 class Einsum(OpFromGraph):
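
For orientation (an illustrative note, not part of the diff): each contraction step under the new CONTRACTION_STEP alias is a plain 3-tuple of operand positions, removed index labels, and the per-step einsum string. A minimal made-up value:

step: tuple[tuple[int, ...], set[str], str] = (
    (0, 1),       # positions of the operands consumed at this step
    {"j"},        # index labels removed by this contraction
    "ij,jk->ik",  # einsum string executed for the step
)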
@@ -329,7 +330,7 @@ def _general_dot(
 def _contraction_list_from_path(
     subscripts: str, operands: Sequence[TensorVariable], path: PATH
-):
+) -> list[CONTRACTION_STEP]:
     """
     Generate a list of contraction steps based on the provided einsum path.
@@ -361,11 +362,6 @@ def _contraction_list_from_path(
         The indices of the contracted indices (those removed from the einsum string at this step)
     - einsum_str: str
         The einsum string for the contraction step
-    - remaining: None
-        The remaining indices. Included to match the output of opt_einsum.contract_path, but not used.
-    - do_blas: None
-        Whether to use blas to perform this step. Included to match the output of opt_einsum.contract_path,
-        but not used.
     """
     fake_operands = [
         np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
@@ -379,8 +375,8 @@ def _contraction_list_from_path(
     input_sets = [set(x) for x in input_list]
     output_set = set(output_subscript)
-    # Build contraction tuple (positions, gemm, einsum_str, remaining)
-    contraction_list = []
+    # Build contraction tuple (positions, removed_idx, step_einsum_str)
+    contraction_list: list[CONTRACTION_STEP] = []
     for cnum, contract_inds in enumerate(path):
         # Make sure we remove inds from right to left
         contract_inds = cast(
@@ -404,7 +400,7 @@ def _contraction_list_from_path(
         einsum_str = ",".join(tmp_inputs) + "->" + idx_result
         # We only need the first three inputs to build the forward graph
-        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction = (contract_inds, idx_removed, einsum_str)
         contraction_list.append(contraction)
     return contraction_list
@@ -568,6 +564,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
     shapes = [operand.type.shape for operand in tensor_operands]
     path: PATH
+    contraction_list: list[CONTRACTION_STEP]
     if any(None in shape for shape in shapes):
         # Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize
         # the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right,
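
As a side note on Case 1 (illustrative, not part of the diff): a default path of (1, 0) contractions just contracts the first two remaining operands at every step until one result is left. A sketch, assuming the path holds one such step per pairwise contraction (n_operands - 1 steps):

n_operands = 4
default_path = tuple((1, 0) for _ in range(n_operands - 1))
print(default_path)  # ((1, 0), (1, 0), (1, 0))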
@@ -591,7 +588,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
     else:
         # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
         # contraction order.
-        _, contraction_list = np.einsum_path(
+        _, contraction_list_raw = np.einsum_path(
             subscripts,
             # Numpy einsum_path requires arrays even though only the shapes matter
             # It's not trivial to duck-type our way around because of internal call to `asanyarray`
@@ -600,6 +597,26 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
             einsum_call=True,  # type: ignore[arg-type]
             optimize="optimal",
         )
+        # Numpy API changed in v2.4.2, and now returns only 3 values instead of 5
+        # We never needed the last two but we need the code to work with both cases
+        contraction_list = []
+        if contraction_list_raw:
+            match len(contraction_list_raw[0]):
+                case 5:
+                    # Old API, the first 3 entries have what we need
+                    contraction_list = [c[:3] for c in contraction_list_raw]  # type: ignore[misc]
+                case 3:
+                    # New API doesn't have index removed
+                    contraction_list = []
+                    for pos, step_ein_str, _ in contraction_list_raw:  # type: ignore[misc]
+                        # e.g., 'ijp,oij->op' -> removed_str = {'i', 'j'}
+                        inp_str, out_str = step_ein_str.replace(",", "").split("->")  # type: ignore[has-type]
+                        removed_idx = set(inp_str) - set(out_str)
+                        contraction_list.append((pos, removed_idx, step_ein_str))  # type: ignore[has-type]
+                case _:
+                    raise ValueError("Unexpected contraction list template")
+        del contraction_list_raw
+
         np_path: PATH | tuple[tuple[int, ...]] = tuple(
             contraction[0]  # type: ignore[misc]
             for contraction in contraction_list
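
A standalone sketch (not part of the diff) of what the shim above recovers: when a contraction step no longer carries the removed-index set, it can be rebuilt from the step's einsum string, as in the 'ijp,oij->op' comment. The step layouts assumed here follow the shim: (positions, removed, einsum_str, _, _) for the old 5-tuple format and (positions, einsum_str, _) for the new 3-tuple one.

import numpy as np

a, b, c = np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))
# einsum_call=True makes einsum_path return the per-step contraction tuples
_, steps = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="optimal", einsum_call=True)
for step in steps:
    positions = step[0]
    step_ein_str = step[2] if len(step) == 5 else step[1]  # old vs new layout (assumed)
    inp_str, out_str = step_ein_str.replace(",", "").split("->")
    removed_idx = set(inp_str) - set(out_str)  # e.g. {'j'} for 'ij,jk->ik'
    print(positions, step_ein_str, removed_idx)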
@@ -669,12 +686,8 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
         return operand.squeeze(squeeze_axes), "".join(names[i] for i in keep_axes)
     einsum_operands = list(tensor_operands)  # So we can pop
-    for operand_indices, contracted_names, einstr, _, _ in contraction_list:
-        contracted_names = sorted(contracted_names)
-        assert len(contracted_names) == len(set(contracted_names)), (
-            "The set was needed!"
-        )
+    for operand_indices, contracted_names_set, einstr in contraction_list:
+        contracted_names = sorted(contracted_names_set)
         input_str, result_names = einstr.split("->")
         input_names = input_str.split(",")
...