提交 0bb15f9d authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix xtensor Transpose with ellipsis

上级 e03605e4
......@@ -189,7 +189,7 @@ class Transpose(XOp):
def transpose(
x,
*dims: str | EllipsisType,
*dim: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
"""Transpose dimensions of the tensor.
......@@ -198,7 +198,7 @@ def transpose(
----------
x : XTensorVariable
Input tensor to transpose.
*dims : str
*dim : str
Dimensions to transpose to. Can include ellipsis (...) to represent
remaining dimensions in their original order.
missing_dims : {"raise", "warn", "ignore"}, optional
......@@ -220,7 +220,7 @@ def transpose(
# Validate dimensions
x = as_xtensor(x)
x_dims = x.type.dims
invalid_dims = set(dims) - {..., *x_dims}
invalid_dims = set(dim) - {..., *x_dims}
if invalid_dims:
if missing_dims != "ignore":
msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}"
......@@ -229,21 +229,27 @@ def transpose(
else:
warnings.warn(msg)
# Handle missing dimensions if not raising
dims = tuple(d for d in dims if d in x_dims or d is ...)
if dims == () or dims == (...,):
dims = tuple(reversed(x_dims))
elif ... in dims:
if dims.count(...) > 1:
dim = tuple(d for d in dim if d in x_dims or d is ...)
if dim == ():
dim = tuple(reversed(x_dims))
elif dim == (...,):
dim = x_dims
elif ... in dim:
if dim.count(...) > 1:
raise ValueError("Ellipsis (...) can only appear once in the dimensions")
# Handle ellipsis expansion
ellipsis_idx = dims.index(...)
pre = dims[:ellipsis_idx]
post = dims[ellipsis_idx + 1 :]
ellipsis_idx = dim.index(...)
pre = dim[:ellipsis_idx]
post = dim[ellipsis_idx + 1 :]
middle = [d for d in x_dims if d not in pre + post]
dims = (*pre, *middle, *post)
dim = (*pre, *middle, *post)
if dim == x_dims:
# No-op transpose
return x
return Transpose(typing.cast(tuple[str], dims))(x)
return Transpose(dims=typing.cast(tuple[str], dim))(x)
class Concat(XOp):
......
......@@ -691,14 +691,14 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#id8
def transpose(
self,
*dims: str | EllipsisType,
*dim: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
"""Transpose dimensions of the tensor.
Parameters
----------
*dims : str | Ellipsis
*dim : str | Ellipsis
Dimensions to transpose. If empty, performs a full transpose.
Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise"
......@@ -718,7 +718,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
If missing_dims="raise" and any dimensions don't exist.
If multiple ellipsis are provided.
"""
return px.shape.transpose(self, *dims, missing_dims=missing_dims)
return px.shape.transpose(self, *dim, missing_dims=missing_dims)
def stack(self, dim, **dims):
return px.shape.stack(self, dim, **dims)
......
......@@ -15,7 +15,6 @@ from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
concat,
stack,
transpose,
unstack,
)
from pytensor.xtensor.type import xtensor
......@@ -46,13 +45,14 @@ def test_transpose():
permutations = [
(a, b, c, d, e), # identity
(e, d, c, b, a), # full tranpose
(), # eqivalent to full transpose
(), # equivalent to full transpose
(a, b, c, e, d), # swap last two dims
(..., d, c), # equivalent to (a, b, e, d, c)
(b, a, ..., e, d), # equivalent to (b, a, c, d, e)
(c, a, ...), # equivalent to (c, a, b, d, e)
(...,), # no op
]
outs = [transpose(x, *perm) for perm in permutations]
outs = [x.transpose(*perm) for perm in permutations]
fn = xr_function([x], outs)
x_test = xr_arange_like(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论