提交 9716b3f2 authored 作者: Allen Downey's avatar Allen Downey 提交者: Ricardo Vieira

Implement squeeze for XTensorVariables

上级 071c4eb8
from pytensor.graph import node_rewriter from pytensor.graph import node_rewriter
from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape from pytensor.tensor import (
broadcast_to,
join,
moveaxis,
specify_shape,
squeeze,
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack from pytensor.xtensor.shape import (
Concat,
Squeeze,
Stack,
Transpose,
UnStack,
)
@register_lower_xtensor @register_lower_xtensor
...@@ -105,3 +117,18 @@ def lower_transpose(fgraph, node): ...@@ -105,3 +117,18 @@ def lower_transpose(fgraph, node):
x_tensor_transposed = x_tensor.transpose(perm) x_tensor_transposed = x_tensor.transpose(perm)
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
return [new_out] return [new_out]
@register_lower_xtensor
@node_rewriter([Squeeze])
def local_squeeze_reshape(fgraph, node):
"""Rewrite Squeeze to tensor.squeeze."""
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
x_dims = x.type.dims
dims_to_remove = node.op.dims
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]
...@@ -297,3 +297,87 @@ class Concat(XOp): ...@@ -297,3 +297,87 @@ class Concat(XOp):
def concat(xtensors, dim: str): def concat(xtensors, dim: str):
return Concat(dim=dim)(*xtensors) return Concat(dim=dim)(*xtensors)
class Squeeze(XOp):
"""Remove specified dimensions from an XTensorVariable.
Only dimensions that are known statically to be size 1 will be removed.
Symbolic dimensions must be explicitly specified, and are assumed safe.
Parameters
----------
dim : tuple of str
The names of the dimensions to remove.
"""
__props__ = ("dims",)
def __init__(self, dims):
self.dims = tuple(sorted(set(dims)))
def make_node(self, x):
x = as_xtensor(x)
# Validate that dims exist and are size-1 if statically known
dims_to_remove = []
x_dims = x.type.dims
x_shape = x.type.shape
for d in self.dims:
if d not in x_dims:
raise ValueError(f"Dimension {d} not found in {x.type.dims}")
idx = x_dims.index(d)
dim_size = x_shape[idx]
if dim_size is not None and dim_size != 1:
raise ValueError(f"Dimension {d} has static size {dim_size}, not 1")
dims_to_remove.append(idx)
new_dims = tuple(
d for i, d in enumerate(x.type.dims) if i not in dims_to_remove
)
new_shape = tuple(
s for i, s in enumerate(x.type.shape) if i not in dims_to_remove
)
out = xtensor(
dtype=x.type.dtype,
shape=new_shape,
dims=new_dims,
)
return Apply(self, [x], [out])
def squeeze(x, dim=None, drop=False, axis=None):
"""Remove dimensions of size 1 from an XTensorVariable."""
x = as_xtensor(x)
# drop parameter is ignored in pytensor.xtensor
if drop is not None:
warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)
# dim and axis are mutually exclusive
if dim is not None and axis is not None:
raise ValueError("Cannot specify both `dim` and `axis`")
# if axis is specified, it must be a sequence of ints
if axis is not None:
if not isinstance(axis, Sequence):
axis = [axis]
if not all(isinstance(a, int) for a in axis):
raise ValueError("axis must be an integer or a sequence of integers")
# convert axis to dims
dims = tuple(x.type.dims[i] for i in axis)
# if dim is specified, it must be a string or a sequence of strings
if dim is None:
dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
elif isinstance(dim, str):
dims = (dim,)
else:
dims = tuple(dim)
if not dims:
return x # no-op if nothing to squeeze
return Squeeze(dims=dims)(x)
...@@ -547,6 +547,32 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -547,6 +547,32 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin") return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
def squeeze(
self,
dim: Sequence[str] | str | None = None,
drop=None,
axis: int | Sequence[int] | None = None,
):
"""Remove dimensions of size 1 from an XTensorVariable.
Parameters
----------
x : XTensorVariable
The input tensor
dim : str or None or iterable of str, optional
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional
If drop=True, drop squeezed coordinates instead of making them scalar.
axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns
-------
XTensorVariable
A new tensor with the specified dimension(s) removed.
"""
return px.shape.squeeze(self, dim, drop, axis)
# ndarray methods # ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7 # https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max): def clip(self, min, max):
......
...@@ -8,10 +8,16 @@ import re ...@@ -8,10 +8,16 @@ import re
from itertools import chain, combinations from itertools import chain, combinations
import numpy as np import numpy as np
import pytest
from xarray import DataArray from xarray import DataArray
from xarray import concat as xr_concat from xarray import concat as xr_concat
from pytensor.xtensor.shape import concat, stack, transpose, unstack from pytensor.xtensor.shape import (
concat,
stack,
transpose,
unstack,
)
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import ( from tests.xtensor.util import (
xr_arange_like, xr_arange_like,
...@@ -21,6 +27,9 @@ from tests.xtensor.util import ( ...@@ -21,6 +27,9 @@ from tests.xtensor.util import (
) )
pytest.importorskip("xarray")
def powerset(iterable, min_group_size=0): def powerset(iterable, min_group_size=0):
"Subsequences of the iterable from shortest to longest." "Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
...@@ -256,3 +265,94 @@ def test_concat_scalar(): ...@@ -256,3 +265,94 @@ def test_concat_scalar():
res = fn(x1_test, x2_test) res = fn(x1_test, x2_test)
expected_res = xr_concat([x1_test, x2_test], dim="new_dim") expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
xr_assert_allclose(res, expected_res) xr_assert_allclose(res, expected_res)
def test_squeeze():
"""Test squeeze."""
# Single dimension
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
y1 = x1.squeeze("country")
fn1 = xr_function([x1], y1)
x1_test = xr_arange_like(x1)
xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country"))
# Multiple dimensions and order independence
x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
y2a = x2.squeeze(["b", "c"])
y2b = x2.squeeze(["c", "b"]) # Test order independence
y2c = x2.squeeze(["b", "b"]) # Test redundant dimensions
y2d = x2.squeeze([]) # Test empty list (no-op)
fn2a = xr_function([x2], y2a)
fn2b = xr_function([x2], y2b)
fn2c = xr_function([x2], y2c)
fn2d = xr_function([x2], y2d)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn2a(x2_test), x2_test.squeeze(["b", "c"]))
xr_assert_allclose(fn2b(x2_test), x2_test.squeeze(["c", "b"]))
xr_assert_allclose(fn2c(x2_test), x2_test.squeeze(["b", "b"]))
xr_assert_allclose(fn2d(x2_test), x2_test)
# Unknown shapes
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown
y3 = x3.squeeze("b")
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
fn3 = xr_function([x3], y3)
xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))
# Mixed known + unknown shapes
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
y4 = x4.squeeze("b")
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
fn4 = xr_function([x4], y4)
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))
# Test axis parameter
x5 = xtensor("x5", dims=("a", "b", "c"), shape=(2, 1, 3))
y5 = x5.squeeze(axis=1) # squeeze dimension at index 1 (b)
fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))
# Test axis parameter with negative index
y5 = x5.squeeze(axis=-1) # squeeze dimension at index -2 (b)
fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))
# Test axis parameter with sequence of ints
y6 = x2.squeeze(axis=[1, 2])
fn6 = xr_function([x2], y6)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))
# Test drop parameter warning
x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
with pytest.warns(
UserWarning, match="drop parameter has no effect in pytensor.xtensor"
):
y7 = x7.squeeze("b", drop=True) # squeeze and drop coordinate
fn7 = xr_function([x7], y7)
x7_test = xr_arange_like(x7)
xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))
def test_squeeze_errors():
"""Test error cases for squeeze."""
# Non-existent dimension
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
with pytest.raises(ValueError, match="Dimension .* not found"):
x1.squeeze("time")
# Dimension size > 1
with pytest.raises(ValueError, match="has static size .* not 1"):
x1.squeeze("city")
# Symbolic shape: dim is not 1 at runtime → should raise
x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown
y2 = x2.squeeze("b")
x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3)))
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
fn2(x2_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论