提交 ae388840 authored 作者: Allen Downey's avatar Allen Downey 提交者: Ricardo Vieira

Add full_like, ones_like, and zeros_like for XTensorVariable (#1514)

上级 815671d5
......@@ -3,7 +3,7 @@ import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, random
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import broadcast, concat
from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
......
......@@ -13,6 +13,7 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.math import cast, second
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape
......@@ -565,3 +566,100 @@ def broadcast(
raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")
# xarray broadcast always returns a tuple, even if there's only one tensor
return tuple(Broadcast(exclude=exclude)(*args, return_list=True)) # type: ignore
def full_like(x, fill_value, dtype=None):
"""Create a new XTensorVariable with the same shape and dimensions, filled with a specified value.
Parameters
----------
x : XTensorVariable
The tensor to fill.
fill_value : scalar or XTensorVariable
The value to fill the new tensor with.
dtype : str or np.dtype, optional
The data type of the new tensor. If None, uses the dtype of the input tensor.
Returns
-------
XTensorVariable
A new tensor with the same shape and dimensions as self, filled with fill_value.
Examples
--------
>>> from pytensor.xtensor import xtensor, full_like
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = full_like(x, 5.0)
>>> assert y.dims == ("a", "b")
>>> assert y.type.shape == (2, 3)
"""
x = as_xtensor(x)
fill_value = as_xtensor(fill_value)
# Check that fill_value is a scalar (ndim=0)
if fill_value.type.ndim != 0:
raise ValueError(
f"fill_value must be a scalar, got ndim={fill_value.type.ndim}"
)
# Handle dtype conversion
if dtype is not None:
# If dtype is specified, cast the fill_value to that dtype
fill_value = cast(fill_value, dtype)
else:
# If dtype is None, cast the fill_value to the input tensor's dtype
# This matches xarray's behavior where it preserves the original dtype
fill_value = cast(fill_value, x.type.dtype)
# Use the xtensor second function
return second(x, fill_value)
def ones_like(x, dtype=None):
"""Create a new XTensorVariable with the same shape and dimensions, filled with ones.
Parameters
----------
x : XTensorVariable
The tensor to fill.
dtype : str or np.dtype, optional
The data type of the new tensor. If None, uses the dtype of the input tensor.
Returns:
XTensorVariable
A new tensor with the same shape and dimensions as self, filled with ones.
Examples
--------
>>> from pytensor.xtensor import xtensor, full_like
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = ones_like(x)
>>> assert y.dims == ("a", "b")
>>> assert y.type.shape == (2, 3)
"""
return full_like(x, 1.0, dtype=dtype)
def zeros_like(x, dtype=None):
"""Create a new XTensorVariable with the same shape and dimensions, filled with zeros.
Parameters
----------
x : XTensorVariable
The tensor to fill.
dtype : str or np.dtype, optional
The data type of the new tensor. If None, uses the dtype of the input tensor.
Returns:
XTensorVariable
A new tensor with the same shape and dimensions as self, filled with zeros.
Examples
--------
>>> from pytensor.xtensor import xtensor, full_like
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = zeros_like(x)
>>> assert y.dims == ("a", "b")
>>> assert y.type.shape == (2, 3)
"""
return full_like(x, 0.0, dtype=dtype)
......@@ -11,13 +11,19 @@ import numpy as np
from xarray import DataArray
from xarray import broadcast as xr_broadcast
from xarray import concat as xr_concat
from xarray import full_like as xr_full_like
from xarray import ones_like as xr_ones_like
from xarray import zeros_like as xr_zeros_like
from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
broadcast,
concat,
full_like,
ones_like,
stack,
unstack,
zeros_like,
)
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
......@@ -633,3 +639,148 @@ class TestBroadcast:
]
for res, expected_res in zip(results, expected_results, strict=True):
xr_assert_allclose(res, expected_res)
def test_full_like():
"""Test full_like function, comparing with xarray's full_like."""
# Basic functionality with scalar fill_value
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
x_test = xr_arange_like(x)
y1 = full_like(x, 5.0)
fn1 = xr_function([x], y1)
result1 = fn1(x_test)
expected1 = xr_full_like(x_test, 5.0)
xr_assert_allclose(result1, expected1, check_dtype=True)
# Other dtypes
x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32")
x_3d_test = xr_arange_like(x_3d)
y7 = full_like(x_3d, -1.0)
fn7 = xr_function([x_3d], y7)
result7 = fn7(x_3d_test)
expected7 = xr_full_like(x_3d_test, -1.0)
xr_assert_allclose(result7, expected7, check_dtype=True)
# Integer dtype
y3 = full_like(x, 5.0, dtype="int32")
fn3 = xr_function([x], y3)
result3 = fn3(x_test)
expected3 = xr_full_like(x_test, 5.0, dtype="int32")
xr_assert_allclose(result3, expected3, check_dtype=True)
# Different fill_value types
y4 = full_like(x, np.array(3.14))
fn4 = xr_function([x], y4)
result4 = fn4(x_test)
expected4 = xr_full_like(x_test, 3.14)
xr_assert_allclose(result4, expected4, check_dtype=True)
# Integer input with float fill_value
x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32")
x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b"))
y5 = full_like(x_int, 2.5)
fn5 = xr_function([x_int], y5)
result5 = fn5(x_int_test)
expected5 = xr_full_like(x_int_test, 2.5)
xr_assert_allclose(result5, expected5, check_dtype=True)
# Symbolic shapes
x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3))
x_sym_test = DataArray(
np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b")
)
y6 = full_like(x_sym, 7.0)
fn6 = xr_function([x_sym], y6)
result6 = fn6(x_sym_test)
expected6 = xr_full_like(x_sym_test, 7.0)
xr_assert_allclose(result6, expected6, check_dtype=True)
# Boolean dtype
x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool")
x_bool_test = DataArray(
np.array([[True, False, True], [False, True, False]]), dims=("a", "b")
)
y8 = full_like(x_bool, True)
fn8 = xr_function([x_bool], y8)
result8 = fn8(x_bool_test)
expected8 = xr_full_like(x_bool_test, True)
xr_assert_allclose(result8, expected8, check_dtype=True)
# Complex dtype
x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64")
x_complex_test = DataArray(
np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b")
)
y9 = full_like(x_complex, 1 + 2j)
fn9 = xr_function([x_complex], y9)
result9 = fn9(x_complex_test)
expected9 = xr_full_like(x_complex_test, 1 + 2j)
xr_assert_allclose(result9, expected9, check_dtype=True)
# Symbolic fill value
x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64")
fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64")
x_sym_fill_test = xr_arange_like(x_sym_fill)
fill_val_test = DataArray(3.14, dims=())
y10 = full_like(x_sym_fill, fill_val)
fn10 = xr_function([x_sym_fill, fill_val], y10)
result10 = fn10(x_sym_fill_test, fill_val_test)
expected10 = xr_full_like(x_sym_fill_test, 3.14)
xr_assert_allclose(result10, expected10, check_dtype=True)
# Test dtype conversion to bool when neither input nor fill_value are bool
x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64")
x_float_test = xr_arange_like(x_float)
y11 = full_like(x_float, 5.0, dtype="bool")
fn11 = xr_function([x_float], y11)
result11 = fn11(x_float_test)
expected11 = xr_full_like(x_float_test, 5.0, dtype="bool")
xr_assert_allclose(result11, expected11, check_dtype=True)
# Verify the result is actually boolean
assert result11.dtype == "bool"
assert expected11.dtype == "bool"
def test_full_like_errors():
"""Test full_like function errors."""
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
x_test = xr_arange_like(x)
with pytest.raises(ValueError, match="fill_value must be a scalar"):
full_like(x, x_test)
def test_ones_like():
"""Test ones_like function, comparing with xarray's ones_like."""
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
x_test = xr_arange_like(x)
y1 = ones_like(x)
fn1 = xr_function([x], y1)
result1 = fn1(x_test)
expected1 = xr_ones_like(x_test)
xr_assert_allclose(result1, expected1)
assert result1.dtype == expected1.dtype
def test_zeros_like():
"""Test zeros_like function, comparing with xarray's zeros_like."""
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
x_test = xr_arange_like(x)
y1 = zeros_like(x)
fn1 = xr_function([x], y1)
result1 = fn1(x_test)
expected1 = xr_zeros_like(x_test)
xr_assert_allclose(result1, expected1)
assert result1.dtype == expected1.dtype
......@@ -37,11 +37,30 @@ def xr_function(*args, **kwargs):
return xfn
def xr_assert_allclose(x, y, *args, **kwargs):
# Assert that two xarray DataArrays are close, ignoring coordinates
def xr_assert_allclose(x, y, check_dtype=False, *args, **kwargs):
"""Assert that two xarray DataArrays are close, ignoring coordinates.
Mostly a wrapper around xarray.testing.assert_allclose,
but with the option to check the dtype.
Parameters
----------
x : xarray.DataArray
The first xarray DataArray to compare.
y : xarray.DataArray
The second xarray DataArray to compare.
check_dtype : bool, optional
If True, check that the dtype of the two DataArrays is the same.
*args :
Additional arguments to pass to xarray.testing.assert_allclose.
**kwargs :
Additional keyword arguments to pass to xarray.testing.assert_allclose.
"""
x = x.drop_vars(x.coords)
y = y.drop_vars(y.coords)
assert_allclose(x, y, *args, **kwargs)
if check_dtype:
assert x.dtype == y.dtype
def xr_arange_like(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论