Commit 133ec80e authored by Allen Downey, committed by Ricardo Vieira

Implement transpose for XTensorVariables

Parent commit: 30b50fda
...@@ -2,7 +2,7 @@ from pytensor.graph import node_rewriter ...@@ -2,7 +2,7 @@ from pytensor.graph import node_rewriter
from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Concat, Stack from pytensor.xtensor.shape import Concat, Stack, Transpose
@register_lower_xtensor @register_lower_xtensor
...@@ -70,3 +70,19 @@ def lower_concat(fgraph, node): ...@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
joined_tensor = join(concat_axis, *bcast_tensor_inputs) joined_tensor = join(concat_axis, *bcast_tensor_inputs)
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
return [new_out] return [new_out]
@register_lower_xtensor
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
    """Lower a Transpose on an XTensorVariable to a plain tensor transpose.

    The output dim order was already validated and fixed in
    ``Transpose.make_node``, so here we only translate names to axes.
    """
    [inp] = node.inputs
    target_dims = node.outputs[0].type.dims
    source_dims = inp.type.dims
    # Map each output dim name back to its axis position in the input.
    axes = tuple(source_dims.index(dim) for dim in target_dims)
    lowered = tensor_from_xtensor(inp).transpose(axes)
    return [xtensor_from_tensor(lowered, dims=target_dims)]
import typing
import warnings
from collections.abc import Sequence from collections.abc import Sequence
from types import EllipsisType
from typing import Literal
from pytensor.graph import Apply from pytensor.graph import Apply
from pytensor.scalar import upcast from pytensor.scalar import upcast
...@@ -72,6 +76,92 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) ...@@ -72,6 +76,92 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
return y return y
class Transpose(XOp):
    """Permute the named dimensions of an XTensorVariable.

    Parameters
    ----------
    dims : Sequence[str]
        A full permutation of the input's dimension names.
    """

    __props__ = ("dims",)

    def __init__(
        self,
        dims: Sequence[str],
    ):
        super().__init__()
        # Stored as a tuple so the op is hashable/comparable via __props__.
        self.dims = tuple(dims)

    def make_node(self, x):
        x = as_xtensor(x)
        transpose_dims = self.dims
        x_shape = x.type.shape
        x_dims = x.type.dims
        # The length check catches duplicated dims (e.g. ("a", "a", "b") vs
        # ("a", "b")), which would otherwise pass the set comparison and
        # yield an output type with the wrong number of dimensions.
        if len(transpose_dims) != len(x_dims) or set(transpose_dims) != set(x_dims):
            raise ValueError(f"{transpose_dims} must be a permuted list of {x_dims}")
        output = xtensor(
            dtype=x.type.dtype,
            # Carry each static shape entry over to its new position.
            shape=tuple(x_shape[x_dims.index(d)] for d in transpose_dims),
            dims=transpose_dims,
        )
        return Apply(self, [x], [output])
def transpose(
    x,
    *dims: str | EllipsisType,
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
    """Transpose dimensions of the tensor.

    Parameters
    ----------
    x : XTensorVariable
        Input tensor to transpose.
    *dims : str
        Dimensions to transpose to. Can include ellipsis (...) to represent
        remaining dimensions in their original order.
    missing_dims : {"raise", "warn", "ignore"}, optional
        How to handle dimensions that don't exist in the input tensor:
        - "raise": Raise an error if any dimensions don't exist (default)
        - "warn": Warn if any dimensions don't exist
        - "ignore": Silently ignore any dimensions that don't exist

    Returns
    -------
    XTensorVariable
        Transposed tensor with reordered dimensions.

    Raises
    ------
    ValueError
        If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
        If multiple ellipsis are provided.
    """
    # Validate dimensions
    x = as_xtensor(x)
    x_dims = x.type.dims
    invalid_dims = set(dims) - {..., *x_dims}
    if invalid_dims:
        if missing_dims != "ignore":
            msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}"
            if missing_dims == "raise":
                raise ValueError(msg)
            else:
                warnings.warn(msg)
        # Handle missing dimensions if not raising
        dims = tuple(d for d in dims if d in x_dims or d is ...)

    if dims == () or dims == (...,):
        # Empty / lone-ellipsis argument means full transpose, as in xarray.
        dims = tuple(reversed(x_dims))
    elif ... in dims:
        if dims.count(...) > 1:
            raise ValueError("Ellipsis (...) can only appear once in the dimensions")
        # Handle ellipsis expansion: dims not named explicitly keep their
        # original relative order in place of the ellipsis.
        ellipsis_idx = dims.index(...)
        pre = dims[:ellipsis_idx]
        post = dims[ellipsis_idx + 1 :]
        middle = [d for d in x_dims if d not in pre + post]
        dims = (*pre, *middle, *post)

    # dims is now a full permutation of x_dims. Note: the cast target must be
    # the variadic tuple[str, ...]; tuple[str] would denote a 1-tuple.
    return Transpose(typing.cast(tuple[str, ...], dims))(x)
class Concat(XOp): class Concat(XOp):
__props__ = ("dim",) __props__ = ("dim",)
......
import typing import typing
from types import EllipsisType
from pytensor.compile import ( from pytensor.compile import (
DeepCopyOp, DeepCopyOp,
...@@ -23,7 +24,7 @@ except ModuleNotFoundError: ...@@ -23,7 +24,7 @@ except ModuleNotFoundError:
XARRAY_AVAILABLE = False XARRAY_AVAILABLE = False
from collections.abc import Sequence from collections.abc import Sequence
from typing import TypeVar from typing import Literal, TypeVar
import numpy as np import numpy as np
...@@ -438,6 +439,19 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -438,6 +439,19 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def real(self): def real(self):
return px.math.real(self) return px.math.real(self)
@property
def T(self):
    """Return the full transpose of the tensor.

    This is equivalent to calling transpose() with no arguments, which
    reverses the order of all dimensions (mirroring xarray's ``.T``).

    Returns
    -------
    XTensorVariable
        Fully transposed tensor.
    """
    return self.transpose()
# Aggregation # Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6 # https://docs.xarray.dev/en/latest/api.html#id6
def all(self, dim=None): def all(self, dim=None):
...@@ -475,6 +489,37 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -475,6 +489,37 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# Reshaping and reorganizing # Reshaping and reorganizing
# https://docs.xarray.dev/en/latest/api.html#id8 # https://docs.xarray.dev/en/latest/api.html#id8
def transpose(
    self,
    *dims: str | EllipsisType,
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
    """Transpose dimensions of the tensor.

    Thin wrapper delegating to ``px.shape.transpose``.

    Parameters
    ----------
    *dims : str | Ellipsis
        Dimensions to transpose. If empty, performs a full transpose.
        Can use ellipsis (...) to represent remaining dimensions.
    missing_dims : {"raise", "warn", "ignore"}, default="raise"
        How to handle dimensions that don't exist in the tensor:
        - "raise": Raise an error if any dimensions don't exist
        - "warn": Warn if any dimensions don't exist
        - "ignore": Silently ignore any dimensions that don't exist

    Returns
    -------
    XTensorVariable
        Transposed tensor with reordered dimensions.

    Raises
    ------
    ValueError
        If missing_dims="raise" and any dimensions don't exist.
        If multiple ellipsis are provided.
    """
    return px.shape.transpose(self, *dims, missing_dims=missing_dims)
def stack(self, dim, **dims): def stack(self, dim, **dims):
return px.shape.stack(self, dim, **dims) return px.shape.stack(self, dim, **dims)
......
...@@ -4,12 +4,13 @@ import pytest ...@@ -4,12 +4,13 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
import re
from itertools import chain, combinations from itertools import chain, combinations
import numpy as np import numpy as np
from xarray import concat as xr_concat from xarray import concat as xr_concat
from pytensor.xtensor.shape import concat, stack from pytensor.xtensor.shape import concat, stack, transpose
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import ( from tests.xtensor.util import (
xr_arange_like, xr_arange_like,
...@@ -28,6 +29,88 @@ def powerset(iterable, min_group_size=0): ...@@ -28,6 +29,88 @@ def powerset(iterable, min_group_size=0):
) )
def test_transpose():
    """Check transpose() against xarray for a range of permutations."""
    a, b, c, d, e = "abcde"
    x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
    permutations = [
        (a, b, c, d, e),  # identity
        (e, d, c, b, a),  # full transpose
        (),  # equivalent to full transpose
        (a, b, c, e, d),  # swap last two dims
        (..., d, c),  # equivalent to (a, b, e, d, c)
        (b, a, ..., e, d),  # equivalent to (b, a, c, d, e)
        (c, a, ...),  # equivalent to (c, a, b, d, e)
    ]
    outs = [transpose(x, *perm) for perm in permutations]

    fn = xr_function([x], outs)
    x_test = xr_arange_like(x)
    res = fn(x_test)
    expected_res = [x_test.transpose(*perm) for perm in permutations]
    # Only the evaluated results are compared; the symbolic outs are not
    # needed in the loop (the original zipped them in unused).
    for res_i, expected_res_i in zip(res, expected_res):
        xr_assert_allclose(res_i, expected_res_i)
def test_xtensor_variable_transpose():
    """Test the transpose() method of XTensorVariable."""
    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
    x_test = xr_arange_like(x)

    # Full transpose (no arguments)
    fn = xr_function([x], x.transpose())
    xr_assert_allclose(fn(x_test), x_test.transpose())

    # Explicit permutation of all dims
    fn = xr_function([x], x.transpose("c", "a", "b"))
    xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))

    # Ellipsis stands in for the remaining dims
    fn = xr_function([x], x.transpose("c", ...))
    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))

    # Unknown dim with the default missing_dims="raise"
    with pytest.raises(
        ValueError,
        match=re.escape(
            "Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
        ),
    ):
        x.transpose("d")

    # More than one ellipsis is rejected
    with pytest.raises(
        ValueError,
        match=re.escape("Ellipsis (...) can only appear once in the dimensions"),
    ):
        x.transpose("a", ..., "b", ...)

    # missing_dims="ignore": the unknown dim is silently dropped
    fn = xr_function([x], x.transpose("c", ..., "d", missing_dims="ignore"))
    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))

    # missing_dims="warn": warns but still drops the unknown dim
    with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
        out = x.transpose("c", ..., "d", missing_dims="warn")
    fn = xr_function([x], out)
    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
def test_xtensor_variable_T():
    """Test the T property of XTensorVariable."""
    # T on a 3D tensor should match xarray's .T (full transpose)
    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
    x_test = xr_arange_like(x)
    fn = xr_function([x], x.T)
    xr_assert_allclose(fn(x_test), x_test.T)
def test_stack(): def test_stack():
dims = ("a", "b", "c", "d") dims = ("a", "b", "c", "d")
x = xtensor("x", dims=dims, shape=(2, 3, 5, 7)) x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))
......
...@@ -33,15 +33,12 @@ def test_xtensortype_filter_variable(): ...@@ -33,15 +33,12 @@ def test_xtensortype_filter_variable():
assert x.type.filter_variable(y1) is y1 assert x.type.filter_variable(y1) is y1
y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2)) y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2))
expected_y2 = as_xtensor(y2.values.transpose(), dims=("a", "b")) expected_y2 = y2.transpose()
assert equal_computations([x.type.filter_variable(y2)], [expected_y2]) assert equal_computations([x.type.filter_variable(y2)], [expected_y2])
y3 = xtensor("y3", dims=("b", "a"), shape=(3, None)) y3 = xtensor("y3", dims=("b", "a"), shape=(3, None))
expected_y3 = as_xtensor( expected_y3 = as_xtensor(
specify_shape( specify_shape(y3.transpose().values, (2, 3)), dims=("a", "b")
as_xtensor(y3.values.transpose(), dims=("a", "b")).values, (2, 3)
),
dims=("a", "b"),
) )
assert equal_computations([x.type.filter_variable(y3)], [expected_y3]) assert equal_computations([x.type.filter_variable(y3)], [expected_y3])
...@@ -116,7 +113,7 @@ def test_minimum_compile(): ...@@ -116,7 +113,7 @@ def test_minimum_compile():
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = as_xtensor(x.values.transpose(), dims=("b", "a")) y = x.transpose()
minimum_mode = Mode(linker="py", optimizer="minimum_compile") minimum_mode = Mode(linker="py", optimizer="minimum_compile")
result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode) result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode)
np.testing.assert_array_equal(result, np.ones((3, 2))) np.testing.assert_array_equal(result, np.ones((3, 2)))
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment