提交 9ede7f6a authored 作者: Allen Downey's avatar Allen Downey 提交者: Ricardo Vieira

Implement expand_dims for XTensorVariables (#1449)

上级 9716b3f2
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
expand_dims,
join,
moveaxis,
specify_shape,
......@@ -10,6 +11,7 @@ from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
......@@ -121,7 +123,7 @@ def lower_transpose(fgraph, node):
@register_lower_xtensor
@node_rewriter([Squeeze])
def local_squeeze_reshape(fgraph, node):
def lower_squeeze(fgraph, node):
"""Rewrite Squeeze to tensor.squeeze."""
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
......@@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node):
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]
@register_lower_xtensor
@node_rewriter([ExpandDims])
def lower_expand_dims(fgraph, node):
    """Lower an ``ExpandDims`` node to plain tensor operations.

    The new dimension is always the leading axis of the output. When the
    output's static leading length is 1 a plain ``expand_dims`` is used;
    otherwise the input is broadcast to ``(size, *input_shape)``.

    Returns a one-element list holding the replacement xtensor output.
    """
    source, length = node.inputs
    target_type = node.outputs[0].type

    # Strip the dimension labels so regular tensor ops can be applied.
    dense = tensor_from_xtensor(source)
    length_tensor = tensor_from_xtensor(length)

    if target_type.shape[0] != 1:
        # Unknown or non-unit length: broadcast along a new leading axis.
        expanded = broadcast_to(dense, (length_tensor, *dense.shape))
    else:
        # Statically known unit length: inserting an axis at the front suffices.
        expanded = expand_dims(dense, 0)

    # Re-attach the static shape that the xtensor output type advertises.
    expanded = specify_shape(expanded, target_type.shape)

    # Wrap the result back into an xtensor carrying the output's dims.
    return [xtensor_from_tensor(expanded, dims=target_type.dims)]
import typing
import warnings
from collections.abc import Sequence
from collections.abc import Hashable, Sequence
from types import EllipsisType
from typing import Literal
import numpy as np
from pytensor.graph import Apply
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
......@@ -381,3 +384,121 @@ def squeeze(x, dim=None, drop=False, axis=None):
return x # no-op if nothing to squeeze
return Squeeze(dims=dims)(x)
class ExpandDims(XOp):
    """Add a new dimension to an XTensorVariable.

    The new dimension is named ``dim`` and is always placed in front of the
    input's existing dimensions.
    """

    __props__ = ("dim",)

    def __init__(self, dim):
        """Store the name of the dimension to add.

        Raises
        ------
        TypeError
            If ``dim`` is not a string.
        """
        if not isinstance(dim, str):
            # Report the offending argument itself: ``self.dim`` has not been
            # assigned at this point, so it cannot be used in the message.
            raise TypeError(f"`dim` must be a string, got: {type(dim)}")
        self.dim = dim

    def make_node(self, x, size):
        """Build the Apply node that prepends ``self.dim`` with length ``size``.

        Parameters
        ----------
        x
            Input, converted with ``as_xtensor``.
        size
            Integer scalar giving the length of the new dimension.

        Raises
        ------
        ValueError
            If ``self.dim`` already exists in ``x``, if ``size`` is not an
            integer scalar, or if ``size`` is a negative constant.
        """
        x = as_xtensor(x)
        if self.dim in x.type.dims:
            raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")

        size = as_xtensor(size, dims=())
        if not (size.dtype in integer_dtypes and size.ndim == 0):
            raise ValueError(f"size should be an integer scalar, got {size.type}")

        # Use the constant value for the static shape when it can be inferred;
        # otherwise the new dimension's length is left unknown (None).
        try:
            static_size = int(get_scalar_constant_value(size))
        except NotScalarConstantError:
            static_size = None

        # If size is a constant, validate it
        if static_size is not None and static_size < 0:
            raise ValueError(f"size must be 0 or positive, got: {static_size}")

        new_shape = (static_size, *x.type.shape)

        # Insert new dim at front
        new_dims = (self.dim, *x.type.dims)

        out = xtensor(
            dtype=x.type.dtype,
            shape=new_shape,
            dims=new_dims,
        )
        return Apply(self, [x, size], [out])
def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
    """Add one or more new dimensions to an XTensorVariable.

    Parameters
    ----------
    x
        Input, converted with ``as_xtensor``.
    dim : str, sequence of hashable, or dict, optional
        A single name or a sequence of names adds dimensions of length 1.
        A dict maps each new name to its length; a sequence or array value
        contributes only its length (a warning is emitted).
    create_index_for_new_dim : optional
        Not supported; any value other than None emits a ``UserWarning``.
    axis : int or sequence of int, optional
        Position(s) of the new dimension(s) in the result, one per new
        dimension. Negative values count from the end of the expanded result.
        When omitted, new dimensions are placed at the front.
    **dim_kwargs
        Alternative to a ``dim`` dict; cannot be combined with ``dim``.

    Raises
    ------
    ValueError
        If both ``dim`` and ``dim_kwargs`` are given, if ``axis`` and the new
        dimensions differ in number, or if ``axis`` contains duplicates.
    TypeError
        If ``dim`` has an unsupported type or a dimension size is a string.
    IndexError
        If an ``axis`` entry is out of bounds for the expanded result.
    """
    x = as_xtensor(x)

    # Store original dimensions for axis handling
    original_dims = x.type.dims

    # Warn if create_index_for_new_dim is used (not supported)
    if create_index_for_new_dim is not None:
        warnings.warn(
            "create_index_for_new_dim=False has no effect in pytensor.xtensor",
            UserWarning,
            stacklevel=2,
        )

    if dim is None:
        dim = dim_kwargs
    elif dim_kwargs:
        raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")

    # Check that dim is Hashable or a sequence of Hashable or dict
    if not isinstance(dim, Hashable):
        if not isinstance(dim, Sequence | dict):
            raise TypeError(f"unhashable type: {type(dim).__name__}")
        if not all(isinstance(d, Hashable) for d in dim):
            raise TypeError(f"unhashable type in {type(dim).__name__}")

    # Normalize to a dimension-size mapping
    if isinstance(dim, str):
        dims_dict = {dim: 1}
    elif isinstance(dim, Sequence) and not isinstance(dim, dict):
        dims_dict = {d: 1 for d in dim}
    elif isinstance(dim, dict):
        dims_dict = {}
        for name, val in dim.items():
            if isinstance(val, str):
                raise TypeError(f"Dimension size cannot be a string: {val}")
            if isinstance(val, Sequence | np.ndarray):
                warnings.warn(
                    "When a sequence is provided as a dimension size, only its length is used. "
                    "The actual values (which would be coordinates in xarray) are ignored.",
                    UserWarning,
                    stacklevel=2,
                )
                dims_dict[name] = len(val)
            else:
                # should be int or symbolic scalar
                dims_dict[name] = val
    else:
        raise TypeError(f"Invalid type for `dim`: {type(dim)}")

    # Insert each new dim at the front (reverse order preserves user intent)
    for name, size in reversed(dims_dict.items()):
        x = ExpandDims(dim=name)(x, size)

    # If axis is specified, transpose to put new dimensions in the right place
    if axis is not None:
        # Wrap non-sequence axis in a list
        if not isinstance(axis, Sequence):
            axis = [axis]

        # require len(axis) == len(dims_dict)
        if len(axis) != len(dims_dict):
            raise ValueError("lengths of dim and axis should be identical.")

        # Every axis entry names a position in the *final* result, so negative
        # values are resolved against the expanded rank rather than against
        # the partially built dimension list.
        result_ndim = len(original_dims) + len(axis)
        positions = []
        for pos in axis:
            if pos < -result_ndim or pos >= result_ndim:
                raise IndexError(
                    f"Axis {pos} is out of bounds of the expanded dimension size {result_ndim}"
                )
            positions.append(pos if pos >= 0 else result_ndim + pos)
        if len(set(positions)) != len(positions):
            raise ValueError("axis should not contain duplicate values")

        # Insert in ascending position order so that a later insertion never
        # shifts a dimension that has already been placed.
        target_dims = list(original_dims)
        for pos, name in sorted(zip(positions, dims_dict), key=lambda pair: pair[0]):
            target_dims.insert(pos, name)
        x = Transpose(dims=tuple(target_dims))(x)

    return x
......@@ -573,6 +573,47 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
"""
return px.shape.squeeze(self, dim, drop, axis)
def expand_dims(
    self,
    dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
    create_index_for_new_dim: bool = True,
    axis: int | Sequence[int] | None = None,
    **dim_kwargs,
):
    """Add one or more new dimensions to the tensor.

    Parameters
    ----------
    dim : str | Sequence[str] | dict[str, int | Sequence] | None
        If str or sequence of str, new dimensions with size 1.
        If dict, keys are dimension names and values are either:
            - int: the new size
            - sequence: coordinates (length determines size)
    create_index_for_new_dim : bool, default: True
        Currently ignored. Reserved for future coordinate support.
        In xarray, when True (default), creates a coordinate index for the new dimension
        with values from 0 to size-1. When False, no coordinate index is created.
        Passing anything other than the default emits a warning that the
        option has no effect.
    axis : int | Sequence[int] | None, default: None
        Position(s) at which to insert the new dimension(s).
        By default (None), new dimensions are inserted at the beginning (axis=0).
        Symbolic axis is not supported yet.
        Negative values count from the end.
    **dim_kwargs : int | Sequence
        Alternative to `dim` dict. Only used if `dim` is None.

    Returns
    -------
    XTensorVariable
        A tensor with the additional dimensions, inserted at the front unless
        `axis` says otherwise.
    """
    # px.shape.expand_dims warns whenever this option is not None. Forward
    # the untouched default as None so ordinary calls stay silent; any
    # explicit non-default value is passed through and still warns.
    forwarded_create_index = (
        None if create_index_for_new_dim is True else create_index_for_new_dim
    )
    return px.shape.expand_dims(
        self,
        dim,
        create_index_for_new_dim=forwarded_create_index,
        axis=axis,
        **dim_kwargs,
    )
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
......
......@@ -8,10 +8,10 @@ import re
from itertools import chain, combinations
import numpy as np
import pytest
from xarray import DataArray
from xarray import concat as xr_concat
from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
concat,
stack,
......@@ -356,3 +356,113 @@ def test_squeeze_errors():
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
fn2(x2_test)
def test_expand_dims():
    """Compare ``XTensorVariable.expand_dims`` against xarray on a 2x2 input."""
    x = xtensor("x", dims=("city", "year"), shape=(2, 2))
    x_test = xr_arange_like(x)

    def check(out, expected, extra_inputs=(), extra_values=()):
        # Compile the graph, evaluate it and compare with the xarray result.
        compiled = xr_function([x, *extra_inputs], out)
        xr_assert_allclose(compiled(x_test, *extra_values), expected)

    # Static specifications: (positional args, keyword args) passed verbatim
    # to both the pytensor and the xarray implementation.
    static_specs = [
        # Implicit size 1
        (("country",), {}),
        # Several new dimensions at once
        ((["country", "state"],), {}),
        # Dict of name-size pairs
        (({"country": 2, "state": 3},), {}),
        # Keyword form of the same request
        ((), {"country": 2, "state": 3}),
        # Dict of name-coordinate-array pairs
        (({"country": np.array([1, 2]), "state": np.array([3, 4, 5])},), {}),
    ]
    for args, kwargs in static_specs:
        check(x.expand_dims(*args, **kwargs), x_test.expand_dims(*args, **kwargs))

    # Symbolic size evaluated to 1
    size_sym_1 = scalar("size_sym_1", dtype="int64")
    check(
        x.expand_dims({"country": size_sym_1}),
        x_test.expand_dims({"country": 1}),
        extra_inputs=(size_sym_1,),
        extra_values=(1,),
    )

    # Two symbolic sizes, given as a dict and then as keyword arguments
    size_sym_2 = scalar("size_sym_2", dtype="int64")
    symbolic_outputs = [
        x.expand_dims({"country": size_sym_1, "state": size_sym_2}),
        x.expand_dims(country=size_sym_1, state=size_sym_2),
    ]
    for out in symbolic_outputs:
        check(
            out,
            x_test.expand_dims({"country": 2, "state": 3}),
            extra_inputs=(size_sym_1, size_sym_2),
            extra_values=(2, 3),
        )

    # Explicit placement through ``axis``
    axis_specs = [
        # Single positive axis
        ("country", 1),
        # Single negative axis
        ("country", -1),
        # Two new dims with positive axes
        (["country", "state"], [1, 2]),
        # Two new dims with negative axes
        (["country", "state"], [-1, -2]),
        # Two new dims mixing negative and positive axes
        (["country", "state"], [-2, 1]),
    ]
    for names, positions in axis_specs:
        check(
            x.expand_dims(names, axis=positions),
            x_test.expand_dims(names, axis=positions),
        )
def test_expand_dims_errors():
    """Check the exceptions raised by ``expand_dims`` for invalid requests."""
    base = xtensor("x", dims=("city",), shape=(3,))

    # Re-adding a dimension the input already has
    with_country = base.expand_dims("country")
    with pytest.raises(ValueError, match="already exists"):
        with_country.expand_dims("city")

    # A dim that is neither a name, a sequence of names nor a dict
    with pytest.raises(TypeError, match="Invalid type for `dim`"):
        base.expand_dims(123)

    # Adding the same new dimension twice
    with_new = base.expand_dims("new")
    with pytest.raises(ValueError, match="already exists"):
        with_new.expand_dims("new")

    # A numpy array as dim is not supported. For reference, xarray's
    # ``x_test.expand_dims(np.array([1, 2]))`` was noted to fail with
    # "TypeError: unhashable type: 'numpy.ndarray'".
    with pytest.raises(TypeError, match="unhashable type"):
        with_new.expand_dims(np.array([1, 2]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论