提交 0bb15f9d authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix xtensor Transpose with ellipsis

上级 e03605e4
...@@ -189,7 +189,7 @@ class Transpose(XOp): ...@@ -189,7 +189,7 @@ class Transpose(XOp):
def transpose( def transpose(
x, x,
*dims: str | EllipsisType, *dim: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise", missing_dims: Literal["raise", "warn", "ignore"] = "raise",
): ):
"""Transpose dimensions of the tensor. """Transpose dimensions of the tensor.
...@@ -198,7 +198,7 @@ def transpose( ...@@ -198,7 +198,7 @@ def transpose(
---------- ----------
x : XTensorVariable x : XTensorVariable
Input tensor to transpose. Input tensor to transpose.
*dims : str *dim : str
Dimensions to transpose to. Can include ellipsis (...) to represent Dimensions to transpose to. Can include ellipsis (...) to represent
remaining dimensions in their original order. remaining dimensions in their original order.
missing_dims : {"raise", "warn", "ignore"}, optional missing_dims : {"raise", "warn", "ignore"}, optional
...@@ -220,7 +220,7 @@ def transpose( ...@@ -220,7 +220,7 @@ def transpose(
# Validate dimensions # Validate dimensions
x = as_xtensor(x) x = as_xtensor(x)
x_dims = x.type.dims x_dims = x.type.dims
invalid_dims = set(dims) - {..., *x_dims} invalid_dims = set(dim) - {..., *x_dims}
if invalid_dims: if invalid_dims:
if missing_dims != "ignore": if missing_dims != "ignore":
msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}" msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}"
...@@ -229,21 +229,27 @@ def transpose( ...@@ -229,21 +229,27 @@ def transpose(
else: else:
warnings.warn(msg) warnings.warn(msg)
# Handle missing dimensions if not raising # Handle missing dimensions if not raising
dims = tuple(d for d in dims if d in x_dims or d is ...) dim = tuple(d for d in dim if d in x_dims or d is ...)
if dims == () or dims == (...,): if dim == ():
dims = tuple(reversed(x_dims)) dim = tuple(reversed(x_dims))
elif ... in dims: elif dim == (...,):
if dims.count(...) > 1: dim = x_dims
elif ... in dim:
if dim.count(...) > 1:
raise ValueError("Ellipsis (...) can only appear once in the dimensions") raise ValueError("Ellipsis (...) can only appear once in the dimensions")
# Handle ellipsis expansion # Handle ellipsis expansion
ellipsis_idx = dims.index(...) ellipsis_idx = dim.index(...)
pre = dims[:ellipsis_idx] pre = dim[:ellipsis_idx]
post = dims[ellipsis_idx + 1 :] post = dim[ellipsis_idx + 1 :]
middle = [d for d in x_dims if d not in pre + post] middle = [d for d in x_dims if d not in pre + post]
dims = (*pre, *middle, *post) dim = (*pre, *middle, *post)
if dim == x_dims:
# No-op transpose
return x
return Transpose(typing.cast(tuple[str], dims))(x) return Transpose(dims=typing.cast(tuple[str], dim))(x)
class Concat(XOp): class Concat(XOp):
......
...@@ -691,14 +691,14 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -691,14 +691,14 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#id8 # https://docs.xarray.dev/en/latest/api.html#id8
def transpose( def transpose(
self, self,
*dims: str | EllipsisType, *dim: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise", missing_dims: Literal["raise", "warn", "ignore"] = "raise",
): ):
"""Transpose dimensions of the tensor. """Transpose dimensions of the tensor.
Parameters Parameters
---------- ----------
*dims : str | Ellipsis *dim : str | Ellipsis
Dimensions to transpose. If empty, performs a full transpose. Dimensions to transpose. If empty, performs a full transpose.
Can use ellipsis (...) to represent remaining dimensions. Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise" missing_dims : {"raise", "warn", "ignore"}, default="raise"
...@@ -718,7 +718,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -718,7 +718,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
If missing_dims="raise" and any dimensions don't exist. If missing_dims="raise" and any dimensions don't exist.
If multiple ellipsis are provided. If multiple ellipsis are provided.
""" """
return px.shape.transpose(self, *dims, missing_dims=missing_dims) return px.shape.transpose(self, *dim, missing_dims=missing_dims)
def stack(self, dim, **dims): def stack(self, dim, **dims):
return px.shape.stack(self, dim, **dims) return px.shape.stack(self, dim, **dims)
......
...@@ -15,7 +15,6 @@ from pytensor.tensor import scalar ...@@ -15,7 +15,6 @@ from pytensor.tensor import scalar
from pytensor.xtensor.shape import ( from pytensor.xtensor.shape import (
concat, concat,
stack, stack,
transpose,
unstack, unstack,
) )
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
...@@ -46,13 +45,14 @@ def test_transpose(): ...@@ -46,13 +45,14 @@ def test_transpose():
permutations = [ permutations = [
(a, b, c, d, e), # identity (a, b, c, d, e), # identity
(e, d, c, b, a), # full tranpose (e, d, c, b, a), # full tranpose
(), # eqivalent to full transpose (), # equivalent to full transpose
(a, b, c, e, d), # swap last two dims (a, b, c, e, d), # swap last two dims
(..., d, c), # equivalent to (a, b, e, d, c) (..., d, c), # equivalent to (a, b, e, d, c)
(b, a, ..., e, d), # equivalent to (b, a, c, d, e) (b, a, ..., e, d), # equivalent to (b, a, c, d, e)
(c, a, ...), # equivalent to (c, a, b, d, e) (c, a, ...), # equivalent to (c, a, b, d, e)
(...,), # no op
] ]
outs = [transpose(x, *perm) for perm in permutations] outs = [x.transpose(*perm) for perm in permutations]
fn = xr_function([x], outs) fn = xr_function([x], outs)
x_test = xr_arange_like(x) x_test = xr_arange_like(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论