提交 9716b3f2 authored 作者: Allen Downey's avatar Allen Downey 提交者: Ricardo Vieira

Implement squeeze for XTensorVariables

上级 071c4eb8
from pytensor.graph import node_rewriter
from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
from pytensor.tensor import (
broadcast_to,
join,
moveaxis,
specify_shape,
squeeze,
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack
from pytensor.xtensor.shape import (
Concat,
Squeeze,
Stack,
Transpose,
UnStack,
)
@register_lower_xtensor
......@@ -105,3 +117,18 @@ def lower_transpose(fgraph, node):
x_tensor_transposed = x_tensor.transpose(perm)
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
return [new_out]
@register_lower_xtensor
@node_rewriter([Squeeze])
def local_squeeze_reshape(fgraph, node):
"""Rewrite Squeeze to tensor.squeeze."""
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
x_dims = x.type.dims
dims_to_remove = node.op.dims
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]
......@@ -297,3 +297,87 @@ class Concat(XOp):
def concat(xtensors, dim: str):
return Concat(dim=dim)(*xtensors)
class Squeeze(XOp):
"""Remove specified dimensions from an XTensorVariable.
Only dimensions that are known statically to be size 1 will be removed.
Symbolic dimensions must be explicitly specified, and are assumed safe.
Parameters
----------
dim : tuple of str
The names of the dimensions to remove.
"""
__props__ = ("dims",)
def __init__(self, dims):
self.dims = tuple(sorted(set(dims)))
def make_node(self, x):
x = as_xtensor(x)
# Validate that dims exist and are size-1 if statically known
dims_to_remove = []
x_dims = x.type.dims
x_shape = x.type.shape
for d in self.dims:
if d not in x_dims:
raise ValueError(f"Dimension {d} not found in {x.type.dims}")
idx = x_dims.index(d)
dim_size = x_shape[idx]
if dim_size is not None and dim_size != 1:
raise ValueError(f"Dimension {d} has static size {dim_size}, not 1")
dims_to_remove.append(idx)
new_dims = tuple(
d for i, d in enumerate(x.type.dims) if i not in dims_to_remove
)
new_shape = tuple(
s for i, s in enumerate(x.type.shape) if i not in dims_to_remove
)
out = xtensor(
dtype=x.type.dtype,
shape=new_shape,
dims=new_dims,
)
return Apply(self, [x], [out])
def squeeze(x, dim=None, drop=False, axis=None):
"""Remove dimensions of size 1 from an XTensorVariable."""
x = as_xtensor(x)
# drop parameter is ignored in pytensor.xtensor
if drop is not None:
warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)
# dim and axis are mutually exclusive
if dim is not None and axis is not None:
raise ValueError("Cannot specify both `dim` and `axis`")
# if axis is specified, it must be a sequence of ints
if axis is not None:
if not isinstance(axis, Sequence):
axis = [axis]
if not all(isinstance(a, int) for a in axis):
raise ValueError("axis must be an integer or a sequence of integers")
# convert axis to dims
dims = tuple(x.type.dims[i] for i in axis)
# if dim is specified, it must be a string or a sequence of strings
if dim is None:
dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
elif isinstance(dim, str):
dims = (dim,)
else:
dims = tuple(dim)
if not dims:
return x # no-op if nothing to squeeze
return Squeeze(dims=dims)(x)
......@@ -547,6 +547,32 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
def squeeze(
self,
dim: Sequence[str] | str | None = None,
drop=None,
axis: int | Sequence[int] | None = None,
):
"""Remove dimensions of size 1 from an XTensorVariable.
Parameters
----------
x : XTensorVariable
The input tensor
dim : str or None or iterable of str, optional
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional
If drop=True, drop squeezed coordinates instead of making them scalar.
axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns
-------
XTensorVariable
A new tensor with the specified dimension(s) removed.
"""
return px.shape.squeeze(self, dim, drop, axis)
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
......
......@@ -8,10 +8,16 @@ import re
from itertools import chain, combinations
import numpy as np
import pytest
from xarray import DataArray
from xarray import concat as xr_concat
from pytensor.xtensor.shape import concat, stack, transpose, unstack
from pytensor.xtensor.shape import (
concat,
stack,
transpose,
unstack,
)
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
xr_arange_like,
......@@ -21,6 +27,9 @@ from tests.xtensor.util import (
)
pytest.importorskip("xarray")
def powerset(iterable, min_group_size=0):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
......@@ -256,3 +265,94 @@ def test_concat_scalar():
res = fn(x1_test, x2_test)
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
xr_assert_allclose(res, expected_res)
def test_squeeze():
"""Test squeeze."""
# Single dimension
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
y1 = x1.squeeze("country")
fn1 = xr_function([x1], y1)
x1_test = xr_arange_like(x1)
xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country"))
# Multiple dimensions and order independence
x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
y2a = x2.squeeze(["b", "c"])
y2b = x2.squeeze(["c", "b"]) # Test order independence
y2c = x2.squeeze(["b", "b"]) # Test redundant dimensions
y2d = x2.squeeze([]) # Test empty list (no-op)
fn2a = xr_function([x2], y2a)
fn2b = xr_function([x2], y2b)
fn2c = xr_function([x2], y2c)
fn2d = xr_function([x2], y2d)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn2a(x2_test), x2_test.squeeze(["b", "c"]))
xr_assert_allclose(fn2b(x2_test), x2_test.squeeze(["c", "b"]))
xr_assert_allclose(fn2c(x2_test), x2_test.squeeze(["b", "b"]))
xr_assert_allclose(fn2d(x2_test), x2_test)
# Unknown shapes
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown
y3 = x3.squeeze("b")
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
fn3 = xr_function([x3], y3)
xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))
# Mixed known + unknown shapes
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
y4 = x4.squeeze("b")
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
fn4 = xr_function([x4], y4)
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))
# Test axis parameter
x5 = xtensor("x5", dims=("a", "b", "c"), shape=(2, 1, 3))
y5 = x5.squeeze(axis=1) # squeeze dimension at index 1 (b)
fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))
# Test axis parameter with negative index
y5 = x5.squeeze(axis=-1) # squeeze dimension at index -2 (b)
fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))
# Test axis parameter with sequence of ints
y6 = x2.squeeze(axis=[1, 2])
fn6 = xr_function([x2], y6)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))
# Test drop parameter warning
x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
with pytest.warns(
UserWarning, match="drop parameter has no effect in pytensor.xtensor"
):
y7 = x7.squeeze("b", drop=True) # squeeze and drop coordinate
fn7 = xr_function([x7], y7)
x7_test = xr_arange_like(x7)
xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))
def test_squeeze_errors():
"""Test error cases for squeeze."""
# Non-existent dimension
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
with pytest.raises(ValueError, match="Dimension .* not found"):
x1.squeeze("time")
# Dimension size > 1
with pytest.raises(ValueError, match="has static size .* not 1"):
x1.squeeze("city")
# Symbolic shape: dim is not 1 at runtime → should raise
x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown
y2 = x2.squeeze("b")
x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3)))
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
fn2(x2_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论