提交 cd1e5dc9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement stack for XTensorVariables

上级 155db9f3
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.shape
from pytensor.graph import node_rewriter
from pytensor.tensor import moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Stack
@register_lower_xtensor
@node_rewriter(tracks=[Stack])
def lower_stack(fgraph, node):
    """Lower a Stack op to moveaxis (+ reshape) on the underlying tensor.

    The stacked dims are moved to the trailing axes and then collapsed into
    a single trailing axis; when only one dim is stacked, the transpose
    alone is enough and the reshape is skipped.
    """
    [x] = node.inputs
    stacked = node.op.stacked_dims
    axes_to_stack = [
        axis for axis, dim in enumerate(x.type.dims) if dim in stacked
    ]
    n_batch = x.type.ndim - len(stacked)

    raw_tensor = tensor_from_xtensor(x)
    trailing = tuple(range(-len(axes_to_stack), 0))
    transposed = moveaxis(raw_tensor, source=axes_to_stack, destination=trailing)

    if n_batch == x.type.ndim - 1:
        # Stacking a "single" dimension: the transpose already does the job.
        # Note: If we have meaningful rewrites before lowering, consider
        # canonicalizing this as a Transpose + Rename
        result = transposed
    else:
        collapsed_shape = (*tuple(transposed.shape)[:n_batch], -1)
        result = transposed.reshape(collapsed_shape)

    return [xtensor_from_tensor(result, dims=node.outputs[0].type.dims)]
from collections.abc import Sequence
from pytensor.graph import Apply
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
class Stack(XOp):
    """Op that fuses several named dims of an xtensor into one new dim.

    Mirrors xarray's ``stack``: the new dimension is appended last, after
    the remaining (batch) dimensions.
    """

    __props__ = ("new_dim_name", "stacked_dims")

    def __init__(self, new_dim_name: str, stacked_dims: tuple[str, ...]):
        super().__init__()
        if not stacked_dims:
            raise ValueError(f"Stacking dims must not be empty: got {stacked_dims}")
        if new_dim_name in stacked_dims:
            raise ValueError(
                f"Stacking dim {new_dim_name} must not be in {stacked_dims}"
            )
        self.new_dim_name = new_dim_name
        self.stacked_dims = stacked_dims

    def make_node(self, x):
        x = as_xtensor(x)
        dims = x.type.dims
        if not (set(self.stacked_dims) <= set(dims)):
            raise ValueError(
                f"Stacking dims {self.stacked_dims} must be a subset of {x.type.dims}"
            )
        if self.new_dim_name in dims:
            raise ValueError(
                f"Stacking dim {self.new_dim_name} must not be in {x.type.dims}"
            )

        # Dims that are not stacked keep their relative order and static shape.
        batch_dims = tuple(d for d in dims if d not in self.stacked_dims)
        batch_shape = tuple(
            s for d, s in zip(dims, x.type.shape) if d not in self.stacked_dims
        )

        # Static size of the new dim is the product of the stacked sizes,
        # or None (unknown) as soon as any stacked size is unknown.
        new_size = 1
        for d, s in zip(dims, x.type.shape):
            if d in self.stacked_dims:
                if s is None:
                    new_size = None
                    break
                new_size *= s

        out = xtensor(
            dtype=x.type.dtype,
            shape=(*batch_shape, new_size),
            dims=(*batch_dims, self.new_dim_name),
        )
        return Apply(self, [x], [out])
def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
if dim is not None:
if dims:
raise ValueError("Cannot use both positional dim and keyword dims in stack")
dims = dim
y = x
for new_dim_name, stacked_dims in dims.items():
if isinstance(stacked_dims, str):
raise TypeError(
f"Stacking dims must be a sequence of strings, got a single string: {stacked_dims}"
)
y = Stack(new_dim_name, tuple(stacked_dims))(y)
return y
......@@ -311,6 +311,11 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def __getitem__(self, idx):
raise NotImplementedError("Indexing not yet implemnented")
# Reshaping and reorganizing
# https://docs.xarray.dev/en/latest/api.html#id8
def stack(self, dim=None, **dims):
    """Stack existing dims into new dims, mirroring xarray's ``DataArray.stack``.

    ``dim`` defaults to None so the keyword-only form
    ``x.stack(new_dim=("a", "b"))`` works, consistent with the
    module-level ``pytensor.xtensor.shape.stack`` signature
    (previously ``dim`` was required positionally, breaking that form).
    """
    return px.shape.stack(self, dim, **dims)
class XTensorConstantSignature(TensorConstantSignature):
    # No xtensor-specific equality/hashing yet; reuses
    # TensorConstantSignature's behavior unchanged.
    pass
......
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
from itertools import chain, combinations
from pytensor.xtensor.shape import stack
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
xr_arange_like,
xr_assert_allclose,
xr_function,
)
def powerset(iterable, min_group_size=0):
    """Yield all subsequences of *iterable*, from shortest to longest."""
    # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    pool = list(iterable)
    for size in range(min_group_size, len(pool) + 1):
        yield from combinations(pool, size)
def test_stack():
    """Stacking every >=2-dim subset matches xarray's DataArray.stack."""
    dims = ("a", "b", "c", "d")
    x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))
    # Materialize the groups once: powerset returns a one-shot iterator,
    # and the same groups are needed for both outputs and expectations.
    groups = list(powerset(dims, min_group_size=2))
    outs = [stack(x, new_dim=dims_to_stack) for dims_to_stack in groups]
    fn = xr_function([x], outs)
    x_test = xr_arange_like(x)
    res = fn(x_test)
    expected_res = [x_test.stack(new_dim=dims_to_stack) for dims_to_stack in groups]
    # Previous loop zipped `outs` in as well but never used it.
    for res_i, expected_res_i in zip(res, expected_res):
        xr_assert_allclose(res_i, expected_res_i)
def test_stack_single_dim():
    """Stacking a single dim keeps values and produces the renamed trailing dim."""
    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 5))
    out = stack(x, {"d": ["a"]})
    assert out.type.dims == ("b", "c", "d")

    fn = xr_function([x], out)
    x_test = xr_arange_like(x)
    xr_assert_allclose(fn(x_test), x_test.stack(d=["a"]))
def test_multiple_stacks():
    """Two stack groups in one call match xarray's multi-group stack."""
    x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7))
    out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d"))

    fn = xr_function([x], [out])
    x_test = xr_arange_like(x)
    result = fn(x_test)
    expected = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
    xr_assert_allclose(result[0], expected)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论