Unverified 提交 1243d583 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Add xtensor.where as a method and root operation (#1985)

上级 55b00d16
...@@ -2,7 +2,7 @@ import warnings ...@@ -2,7 +2,7 @@ import warnings
import pytensor.xtensor.rewriting import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, math, random, signal from pytensor.xtensor import linalg, math, random, signal
from pytensor.xtensor.math import dot from pytensor.xtensor.math import dot, where
from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
from pytensor.xtensor.type import ( from pytensor.xtensor.type import (
as_xtensor, as_xtensor,
......
...@@ -696,8 +696,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -696,8 +696,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
Parameters Parameters
---------- ----------
x : XTensorVariable
The input tensor
dim : str or None or iterable of str, optional dim : str or None or iterable of str, optional
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime. (known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
...@@ -758,6 +756,32 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -758,6 +756,32 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
**dim_kwargs, **dim_kwargs,
) )
# Missing value handling
# https://docs.xarray.dev/en/stable/api/dataarray.html#missing-value-handling
def where(self, cond, other=None, drop: bool = False):
"""Filter elements from this object according to a condition.
Parameters
----------
cond : Variable
Locations at which to preserve this object's values.
other: Variable, optional
Value to use for locations in this object where cond is False.
By default, these locations are filled with nan (which may cause upcasting).
drop: bool
Ignored by PyTensor
Returns
-------
XTensorVariable
A tensor with additional dimensions inserted at the front.
"""
if other is None:
other = np.nan
res = px.math.where(cond, self, other)
# xarray puts self dims first
return res.transpose(*self.dims, ...)
# ndarray methods # ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7 # https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max): def clip(self, min, max):
......
...@@ -23,6 +23,7 @@ from pytensor.xtensor.type import ( ...@@ -23,6 +23,7 @@ from pytensor.xtensor.type import (
xtensor_constant, xtensor_constant,
xtensor_shared, xtensor_shared,
) )
from tests.xtensor.util import xr_assert_allclose, xr_function
def test_xtensortype(): def test_xtensortype():
...@@ -231,3 +232,26 @@ def test_isel_missing_dims(): ...@@ -231,3 +232,26 @@ def test_isel_missing_dims():
x.isel(c=0, missing_dims="warn") x.isel(c=0, missing_dims="warn")
x.isel(c=0, missing_dims="ignore").dims == ("a", "b") x.isel(c=0, missing_dims="ignore").dims == ("a", "b")
def test_where():
a = xtensor(dims=("a", "b"))
a_test = DataArray(np.arange(6).reshape(2, 3), dims=a.dims)
# Implicit other
out = a.where(a > 1)
res = xr_function([a], out)(a_test)
expected = a_test.where(a_test > 1)
xr_assert_allclose(res, expected)
# Explicit other
out = a.where(a > 1, 99)
res = xr_function([a], out)(a_test)
expected = a_test.where(a_test > 1, 99)
xr_assert_allclose(res, expected)
# Case that would fail if we didn't transpose
out = a[0].where(a > 1, -1)
res = xr_function([a], out)(a_test)
expected = a_test[0].where(a_test > 1, -1)
xr_assert_allclose(res, expected)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论