提交 cc4b77a0 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove useless arguments from non xarray API

Make sure we don't issue unexpected warnings
上级 8da7cd76
......@@ -163,6 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"]
"tests/xtensor/**/test_*.py" = ["E402"]
......
import typing
import warnings
from collections.abc import Hashable, Sequence
from collections.abc import Sequence
from types import EllipsisType
from typing import Literal
......@@ -384,28 +384,10 @@ class Squeeze(XOp):
return Apply(self, [x], [out])
def squeeze(x, dim=None, drop=False, axis=None):
def squeeze(x, dim: str | Sequence[str] | None = None):
"""Remove dimensions of size 1 from an XTensorVariable."""
x = as_xtensor(x)
# drop parameter is ignored in pytensor.xtensor
if drop is not None:
warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)
# dim and axis are mutually exclusive
if dim is not None and axis is not None:
raise ValueError("Cannot specify both `dim` and `axis`")
# if axis is specified, it must be a sequence of ints
if axis is not None:
if not isinstance(axis, Sequence):
axis = [axis]
if not all(isinstance(a, int) for a in axis):
raise ValueError("axis must be an integer or a sequence of integers")
# convert axis to dims
dims = tuple(x.type.dims[i] for i in axis)
# if dim is specified, it must be a string or a sequence of strings
if dim is None:
dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
......@@ -461,33 +443,18 @@ class ExpandDims(XOp):
return Apply(self, [x, size], [out])
def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
def expand_dims(x, dim=None, axis=None, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable."""
x = as_xtensor(x)
# Store original dimensions for axis handling
original_dims = x.type.dims
# Warn if create_index_for_new_dim is used (not supported)
if create_index_for_new_dim is not None:
warnings.warn(
"create_index_for_new_dim=False has no effect in pytensor.xtensor",
UserWarning,
stacklevel=2,
)
if dim is None:
dim = dim_kwargs
elif dim_kwargs:
raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")
# Check that dim is Hashable or a sequence of Hashable or dict
if not isinstance(dim, Hashable):
if not isinstance(dim, Sequence | dict):
raise TypeError(f"unhashable type: {type(dim).__name__}")
if not all(isinstance(d, Hashable) for d in dim):
raise TypeError(f"unhashable type in {type(dim).__name__}")
# Normalize to a dimension-size mapping
if isinstance(dim, str):
dims_dict = {dim: 1}
......@@ -496,9 +463,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
elif isinstance(dim, dict):
dims_dict = {}
for name, val in dim.items():
if isinstance(val, str):
raise TypeError(f"Dimension size cannot be a string: {val}")
if isinstance(val, Sequence | np.ndarray):
if isinstance(val, list | tuple | np.ndarray):
warnings.warn(
"When a sequence is provided as a dimension size, only its length is used. "
"The actual values (which would be coordinates in xarray) are ignored.",
......
......@@ -687,7 +687,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional
If drop=True, drop squeezed coordinates instead of making them scalar.
Ignored by PyTensor.
axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns
......@@ -695,12 +695,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
XTensorVariable
A new tensor with the specified dimension(s) removed.
"""
return px.shape.squeeze(self, dim, drop, axis)
if axis is not None:
if dim is not None:
raise ValueError("Cannot specify both `dim` and `axis`")
if not isinstance(axis, Sequence):
axis = (axis,)
dim = tuple(self.type.dims[i] for i in axis)
return px.shape.squeeze(self, dim)
def expand_dims(
self,
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
create_index_for_new_dim: bool = True,
create_index_for_new_dim: bool | None = None,
axis: int | Sequence[int] | None = None,
**dim_kwargs,
):
......@@ -714,7 +723,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True
create_index_for_new_dim : bool, optional
Ignored by PyTensor
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
......@@ -730,7 +739,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.shape.expand_dims(
self,
dim,
create_index_for_new_dim=create_index_for_new_dim,
axis=axis,
**dim_kwargs,
)
......
......@@ -2,6 +2,8 @@ import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import re
......
......@@ -3,6 +3,7 @@ import pytest
pytest.importorskip("xarray")
pytest.importorskip("xarray_einstats")
pytestmark = pytest.mark.filterwarnings("error")
import numpy as np
from xarray import DataArray
......
......@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import inspect
......
import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import inspect
import re
from copy import deepcopy
import numpy as np
import pytest
import pytensor.tensor.random as ptr
import pytensor.xtensor.random as pxr
......
......@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
......
......@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import re
from itertools import chain, combinations
......@@ -33,9 +34,6 @@ from tests.xtensor.util import (
)
pytest.importorskip("xarray")
def powerset(iterable, min_group_size=0):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
......@@ -322,7 +320,7 @@ def test_squeeze():
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))
# Test axis parameter with negative index
y5 = x5.squeeze(axis=-1) # squeeze dimension at index -2 (b)
y5 = x5.squeeze(axis=-2) # squeeze dimension at index -2 (b)
fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))
......@@ -333,12 +331,9 @@ def test_squeeze():
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))
# Test drop parameter warning
# Test drop parameter ignored, but accepted
x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
with pytest.warns(
UserWarning, match="drop parameter has no effect in pytensor.xtensor"
):
y7 = x7.squeeze("b", drop=True) # squeeze and drop coordinate
y7 = x7.squeeze("b", drop=True)
fn7 = xr_function([x7], y7)
x7_test = xr_arange_like(x7)
xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))
......@@ -391,7 +386,8 @@ def test_expand_dims():
xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3))
# Test with a dict of name-coord array pairs
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})
with pytest.warns(UserWarning, match="only its length is used"):
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})
fn = xr_function([x], y)
xr_assert_allclose(
fn(x_test),
......@@ -471,7 +467,7 @@ def test_expand_dims_errors():
# TypeError: unhashable type: 'numpy.ndarray'
# Test with a numpy array as dim (not supported)
with pytest.raises(TypeError, match="unhashable type"):
with pytest.raises(TypeError, match="Invalid type for `dim`"):
y.expand_dims(np.array([1, 2]))
......
......@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import numpy as np
from xarray import DataArray
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论