提交 133ec80e authored 作者: Allen Downey's avatar Allen Downey 提交者: Ricardo Vieira

Implement transpose for XTensorVariables

上级 30b50fda
......@@ -2,7 +2,7 @@ from pytensor.graph import node_rewriter
from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Concat, Stack
from pytensor.xtensor.shape import Concat, Stack, Transpose
@register_lower_xtensor
......@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
return [new_out]
@register_lower_xtensor
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
[x] = node.inputs
# Use the final dimensions that were already computed in make_node
out_dims = node.outputs[0].type.dims
in_dims = x.type.dims
# Compute the permutation based on the final dimensions
perm = tuple(in_dims.index(d) for d in out_dims)
x_tensor = tensor_from_xtensor(x)
x_tensor_transposed = x_tensor.transpose(perm)
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
return [new_out]
import typing
import warnings
from collections.abc import Sequence
from types import EllipsisType
from typing import Literal
from pytensor.graph import Apply
from pytensor.scalar import upcast
......@@ -72,6 +76,92 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
return y
class Transpose(XOp):
__props__ = ("dims",)
def __init__(
self,
dims: Sequence[str],
):
super().__init__()
self.dims = tuple(dims)
def make_node(self, x):
x = as_xtensor(x)
transpose_dims = self.dims
x_shape = x.type.shape
x_dims = x.type.dims
if set(transpose_dims) != set(x_dims):
raise ValueError(f"{transpose_dims} must be a permuted list of {x_dims}")
output = xtensor(
dtype=x.type.dtype,
shape=tuple(x_shape[x_dims.index(d)] for d in transpose_dims),
dims=transpose_dims,
)
return Apply(self, [x], [output])
def transpose(
x,
*dims: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
"""Transpose dimensions of the tensor.
Parameters
----------
x : XTensorVariable
Input tensor to transpose.
*dims : str
Dimensions to transpose to. Can include ellipsis (...) to represent
remaining dimensions in their original order.
missing_dims : {"raise", "warn", "ignore"}, optional
How to handle dimensions that don't exist in the input tensor:
- "raise": Raise an error if any dimensions don't exist (default)
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist
Returns
-------
XTensorVariable
Transposed tensor with reordered dimensions.
Raises
------
ValueError
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
"""
# Validate dimensions
x = as_xtensor(x)
x_dims = x.type.dims
invalid_dims = set(dims) - {..., *x_dims}
if invalid_dims:
if missing_dims != "ignore":
msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}"
if missing_dims == "raise":
raise ValueError(msg)
else:
warnings.warn(msg)
# Handle missing dimensions if not raising
dims = tuple(d for d in dims if d in x_dims or d is ...)
if dims == () or dims == (...,):
dims = tuple(reversed(x_dims))
elif ... in dims:
if dims.count(...) > 1:
raise ValueError("Ellipsis (...) can only appear once in the dimensions")
# Handle ellipsis expansion
ellipsis_idx = dims.index(...)
pre = dims[:ellipsis_idx]
post = dims[ellipsis_idx + 1 :]
middle = [d for d in x_dims if d not in pre + post]
dims = (*pre, *middle, *post)
return Transpose(typing.cast(tuple[str], dims))(x)
class Concat(XOp):
__props__ = ("dim",)
......
import typing
from types import EllipsisType
from pytensor.compile import (
DeepCopyOp,
......@@ -23,7 +24,7 @@ except ModuleNotFoundError:
XARRAY_AVAILABLE = False
from collections.abc import Sequence
from typing import TypeVar
from typing import Literal, TypeVar
import numpy as np
......@@ -438,6 +439,19 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def real(self):
return px.math.real(self)
@property
def T(self):
"""Return the full transpose of the tensor.
This is equivalent to calling transpose() with no arguments.
Returns
-------
XTensorVariable
Fully transposed tensor.
"""
return self.transpose()
# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
def all(self, dim=None):
......@@ -475,6 +489,37 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# Reshaping and reorganizing
# https://docs.xarray.dev/en/latest/api.html#id8
def transpose(
self,
*dims: str | EllipsisType,
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
"""Transpose dimensions of the tensor.
Parameters
----------
*dims : str | Ellipsis
Dimensions to transpose. If empty, performs a full transpose.
Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise"
How to handle dimensions that don't exist in the tensor:
- "raise": Raise an error if any dimensions don't exist
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist
Returns
-------
XTensorVariable
Transposed tensor with reordered dimensions.
Raises
------
ValueError
If missing_dims="raise" and any dimensions don't exist.
If multiple ellipsis are provided.
"""
return px.shape.transpose(self, *dims, missing_dims=missing_dims)
def stack(self, dim, **dims):
return px.shape.stack(self, dim, **dims)
......
......@@ -4,12 +4,13 @@ import pytest
pytest.importorskip("xarray")
import re
from itertools import chain, combinations
import numpy as np
from xarray import concat as xr_concat
from pytensor.xtensor.shape import concat, stack
from pytensor.xtensor.shape import concat, stack, transpose
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
xr_arange_like,
......@@ -28,6 +29,88 @@ def powerset(iterable, min_group_size=0):
)
def test_transpose():
a, b, c, d, e = "abcde"
x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
permutations = [
(a, b, c, d, e), # identity
(e, d, c, b, a), # full tranpose
(), # eqivalent to full transpose
(a, b, c, e, d), # swap last two dims
(..., d, c), # equivalent to (a, b, e, d, c)
(b, a, ..., e, d), # equivalent to (b, a, c, d, e)
(c, a, ...), # equivalent to (c, a, b, d, e)
]
outs = [transpose(x, *perm) for perm in permutations]
fn = xr_function([x], outs)
x_test = xr_arange_like(x)
res = fn(x_test)
expected_res = [x_test.transpose(*perm) for perm in permutations]
for outs_i, res_i, expected_res_i in zip(outs, res, expected_res):
xr_assert_allclose(res_i, expected_res_i)
def test_xtensor_variable_transpose():
"""Test the transpose() method of XTensorVariable."""
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
# Test basic transpose
out = x.transpose()
fn = xr_function([x], out)
x_test = xr_arange_like(x)
xr_assert_allclose(fn(x_test), x_test.transpose())
# Test transpose with specific dimensions
out = x.transpose("c", "a", "b")
fn = xr_function([x], out)
xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
# Test transpose with ellipsis
out = x.transpose("c", ...)
fn = xr_function([x], out)
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
# Test error cases
with pytest.raises(
ValueError,
match=re.escape(
"Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
),
):
x.transpose("d")
with pytest.raises(
ValueError,
match=re.escape("Ellipsis (...) can only appear once in the dimensions"),
):
x.transpose("a", ..., "b", ...)
# Test missing_dims parameter
# Test ignore
out = x.transpose("c", ..., "d", missing_dims="ignore")
fn = xr_function([x], out)
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
# Test warn
with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
out = x.transpose("c", ..., "d", missing_dims="warn")
fn = xr_function([x], out)
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
def test_xtensor_variable_T():
"""Test the T property of XTensorVariable."""
# Test T property with 3D tensor
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
out = x.T
fn = xr_function([x], out)
x_test = xr_arange_like(x)
xr_assert_allclose(fn(x_test), x_test.T)
def test_stack():
dims = ("a", "b", "c", "d")
x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))
......
......@@ -33,15 +33,12 @@ def test_xtensortype_filter_variable():
assert x.type.filter_variable(y1) is y1
y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2))
expected_y2 = as_xtensor(y2.values.transpose(), dims=("a", "b"))
expected_y2 = y2.transpose()
assert equal_computations([x.type.filter_variable(y2)], [expected_y2])
y3 = xtensor("y3", dims=("b", "a"), shape=(3, None))
expected_y3 = as_xtensor(
specify_shape(
as_xtensor(y3.values.transpose(), dims=("a", "b")).values, (2, 3)
),
dims=("a", "b"),
specify_shape(y3.transpose().values, (2, 3)), dims=("a", "b")
)
assert equal_computations([x.type.filter_variable(y3)], [expected_y3])
......@@ -116,7 +113,7 @@ def test_minimum_compile():
from pytensor.compile.mode import Mode
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = as_xtensor(x.values.transpose(), dims=("b", "a"))
y = x.transpose()
minimum_mode = Mode(linker="py", optimizer="minimum_compile")
result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode)
np.testing.assert_array_equal(result, np.ones((3, 2)))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论