提交 cc4b77a0 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove useless arguments from non xarray API

Make sure we don't issue unexpected warnings
上级 8da7cd76
...@@ -163,6 +163,7 @@ lines-after-imports = 2 ...@@ -163,6 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py" = ["E402"] "tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"] "tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"] "tests/link/mlx/**/test_*.py" = ["E402"]
"tests/xtensor/**/test_*.py" = ["E402"]
......
import typing import typing
import warnings import warnings
from collections.abc import Hashable, Sequence from collections.abc import Sequence
from types import EllipsisType from types import EllipsisType
from typing import Literal from typing import Literal
...@@ -384,28 +384,10 @@ class Squeeze(XOp): ...@@ -384,28 +384,10 @@ class Squeeze(XOp):
return Apply(self, [x], [out]) return Apply(self, [x], [out])
def squeeze(x, dim=None, drop=False, axis=None): def squeeze(x, dim: str | Sequence[str] | None = None):
"""Remove dimensions of size 1 from an XTensorVariable.""" """Remove dimensions of size 1 from an XTensorVariable."""
x = as_xtensor(x) x = as_xtensor(x)
# drop parameter is ignored in pytensor.xtensor
if drop is not None:
warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)
# dim and axis are mutually exclusive
if dim is not None and axis is not None:
raise ValueError("Cannot specify both `dim` and `axis`")
# if axis is specified, it must be a sequence of ints
if axis is not None:
if not isinstance(axis, Sequence):
axis = [axis]
if not all(isinstance(a, int) for a in axis):
raise ValueError("axis must be an integer or a sequence of integers")
# convert axis to dims
dims = tuple(x.type.dims[i] for i in axis)
# if dim is specified, it must be a string or a sequence of strings # if dim is specified, it must be a string or a sequence of strings
if dim is None: if dim is None:
dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1) dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
...@@ -461,33 +443,18 @@ class ExpandDims(XOp): ...@@ -461,33 +443,18 @@ class ExpandDims(XOp):
return Apply(self, [x, size], [out]) return Apply(self, [x, size], [out])
def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs): def expand_dims(x, dim=None, axis=None, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable.""" """Add one or more new dimensions to an XTensorVariable."""
x = as_xtensor(x) x = as_xtensor(x)
# Store original dimensions for axis handling # Store original dimensions for axis handling
original_dims = x.type.dims original_dims = x.type.dims
# Warn if create_index_for_new_dim is used (not supported)
if create_index_for_new_dim is not None:
warnings.warn(
"create_index_for_new_dim=False has no effect in pytensor.xtensor",
UserWarning,
stacklevel=2,
)
if dim is None: if dim is None:
dim = dim_kwargs dim = dim_kwargs
elif dim_kwargs: elif dim_kwargs:
raise ValueError("Cannot specify both `dim` and `**dim_kwargs`") raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")
# Check that dim is Hashable or a sequence of Hashable or dict
if not isinstance(dim, Hashable):
if not isinstance(dim, Sequence | dict):
raise TypeError(f"unhashable type: {type(dim).__name__}")
if not all(isinstance(d, Hashable) for d in dim):
raise TypeError(f"unhashable type in {type(dim).__name__}")
# Normalize to a dimension-size mapping # Normalize to a dimension-size mapping
if isinstance(dim, str): if isinstance(dim, str):
dims_dict = {dim: 1} dims_dict = {dim: 1}
...@@ -496,9 +463,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa ...@@ -496,9 +463,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
elif isinstance(dim, dict): elif isinstance(dim, dict):
dims_dict = {} dims_dict = {}
for name, val in dim.items(): for name, val in dim.items():
if isinstance(val, str): if isinstance(val, list | tuple | np.ndarray):
raise TypeError(f"Dimension size cannot be a string: {val}")
if isinstance(val, Sequence | np.ndarray):
warnings.warn( warnings.warn(
"When a sequence is provided as a dimension size, only its length is used. " "When a sequence is provided as a dimension size, only its length is used. "
"The actual values (which would be coordinates in xarray) are ignored.", "The actual values (which would be coordinates in xarray) are ignored.",
......
...@@ -687,7 +687,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -687,7 +687,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime. (known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional drop : bool, optional
If drop=True, drop squeezed coordinates instead of making them scalar. Ignored by PyTensor.
axis : int or iterable of int, optional axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed. The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns Returns
...@@ -695,12 +695,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -695,12 +695,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
XTensorVariable XTensorVariable
A new tensor with the specified dimension(s) removed. A new tensor with the specified dimension(s) removed.
""" """
return px.shape.squeeze(self, dim, drop, axis) if axis is not None:
if dim is not None:
raise ValueError("Cannot specify both `dim` and `axis`")
if not isinstance(axis, Sequence):
axis = (axis,)
dim = tuple(self.type.dims[i] for i in axis)
return px.shape.squeeze(self, dim)
def expand_dims( def expand_dims(
self, self,
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None, dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
create_index_for_new_dim: bool = True, create_index_for_new_dim: bool | None = None,
axis: int | Sequence[int] | None = None, axis: int | Sequence[int] | None = None,
**dim_kwargs, **dim_kwargs,
): ):
...@@ -714,7 +723,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -714,7 +723,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
- int: the new size - int: the new size
- sequence: coordinates (length determines size) - sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True create_index_for_new_dim : bool, optional
Ignored by PyTensor Ignored by PyTensor
axis : int | Sequence[int] | None, default: None axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s). Not implemented yet. In xarray, specifies where to insert the new dimension(s).
...@@ -730,7 +739,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -730,7 +739,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return px.shape.expand_dims( return px.shape.expand_dims(
self, self,
dim, dim,
create_index_for_new_dim=create_index_for_new_dim,
axis=axis, axis=axis,
**dim_kwargs, **dim_kwargs,
) )
......
...@@ -2,6 +2,8 @@ import pytest ...@@ -2,6 +2,8 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import re import re
......
...@@ -3,6 +3,7 @@ import pytest ...@@ -3,6 +3,7 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytest.importorskip("xarray_einstats") pytest.importorskip("xarray_einstats")
pytestmark = pytest.mark.filterwarnings("error")
import numpy as np import numpy as np
from xarray import DataArray from xarray import DataArray
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import inspect import inspect
......
import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import inspect import inspect
import re import re
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
import pytest
import pytensor.tensor.random as ptr import pytensor.tensor.random as ptr
import pytensor.xtensor.random as pxr import pytensor.xtensor.random as pxr
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import re import re
from itertools import chain, combinations from itertools import chain, combinations
...@@ -33,9 +34,6 @@ from tests.xtensor.util import ( ...@@ -33,9 +34,6 @@ from tests.xtensor.util import (
) )
pytest.importorskip("xarray")
def powerset(iterable, min_group_size=0): def powerset(iterable, min_group_size=0):
"Subsequences of the iterable from shortest to longest." "Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
...@@ -322,7 +320,7 @@ def test_squeeze(): ...@@ -322,7 +320,7 @@ def test_squeeze():
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1)) xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))
# Test axis parameter with negative index # Test axis parameter with negative index
y5 = x5.squeeze(axis=-1) # squeeze dimension at index -2 (b) y5 = x5.squeeze(axis=-2) # squeeze dimension at index -2 (b)
fn5 = xr_function([x5], y5) fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5) x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2)) xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))
...@@ -333,12 +331,9 @@ def test_squeeze(): ...@@ -333,12 +331,9 @@ def test_squeeze():
x2_test = xr_arange_like(x2) x2_test = xr_arange_like(x2)
xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2])) xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))
# Test drop parameter warning # Test drop parameter ignored, but accepted
x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1)) x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
with pytest.warns( y7 = x7.squeeze("b", drop=True)
UserWarning, match="drop parameter has no effect in pytensor.xtensor"
):
y7 = x7.squeeze("b", drop=True) # squeeze and drop coordinate
fn7 = xr_function([x7], y7) fn7 = xr_function([x7], y7)
x7_test = xr_arange_like(x7) x7_test = xr_arange_like(x7)
xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True)) xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))
...@@ -391,7 +386,8 @@ def test_expand_dims(): ...@@ -391,7 +386,8 @@ def test_expand_dims():
xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3)) xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3))
# Test with a dict of name-coord array pairs # Test with a dict of name-coord array pairs
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}) with pytest.warns(UserWarning, match="only its length is used"):
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose( xr_assert_allclose(
fn(x_test), fn(x_test),
...@@ -471,7 +467,7 @@ def test_expand_dims_errors(): ...@@ -471,7 +467,7 @@ def test_expand_dims_errors():
# TypeError: unhashable type: 'numpy.ndarray' # TypeError: unhashable type: 'numpy.ndarray'
# Test with a numpy array as dim (not supported) # Test with a numpy array as dim (not supported)
with pytest.raises(TypeError, match="unhashable type"): with pytest.raises(TypeError, match="Invalid type for `dim`"):
y.expand_dims(np.array([1, 2])) y.expand_dims(np.array([1, 2]))
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import numpy as np import numpy as np
from xarray import DataArray from xarray import DataArray
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论