Unverified 提交 1243d583 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Add xtensor.where as a method and root operation (#1985)

上级 55b00d16
......@@ -2,7 +2,7 @@ import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, math, random, signal
from pytensor.xtensor.math import dot
from pytensor.xtensor.math import dot, where
from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
from pytensor.xtensor.type import (
as_xtensor,
......
......@@ -696,8 +696,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
Parameters
----------
x : XTensorVariable
The input tensor
dim : str or None or iterable of str, optional
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
......@@ -758,6 +756,32 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
**dim_kwargs,
)
# Missing value handling
# https://docs.xarray.dev/en/stable/api/dataarray.html#missing-value-handling
def where(self, cond, other=None, drop: bool = False):
"""Filter elements from this object according to a condition.
Parameters
----------
cond : Variable
Locations at which to preserve this object's values.
other: Variable, optional
Value to use for locations in this object where cond is False.
By default, these locations are filled with nan (which may cause upcasting).
drop: bool
Ignored by PyTensor
Returns
-------
XTensorVariable
A tensor with additional dimensions inserted at the front.
"""
if other is None:
other = np.nan
res = px.math.where(cond, self, other)
# xarray puts self dims first
return res.transpose(*self.dims, ...)
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
......
......@@ -23,6 +23,7 @@ from pytensor.xtensor.type import (
xtensor_constant,
xtensor_shared,
)
from tests.xtensor.util import xr_assert_allclose, xr_function
def test_xtensortype():
......@@ -231,3 +232,26 @@ def test_isel_missing_dims():
x.isel(c=0, missing_dims="warn")
x.isel(c=0, missing_dims="ignore").dims == ("a", "b")
def test_where():
a = xtensor(dims=("a", "b"))
a_test = DataArray(np.arange(6).reshape(2, 3), dims=a.dims)
# Implicit other
out = a.where(a > 1)
res = xr_function([a], out)(a_test)
expected = a_test.where(a_test > 1)
xr_assert_allclose(res, expected)
# Explicit other
out = a.where(a > 1, 99)
res = xr_function([a], out)(a_test)
expected = a_test.where(a_test > 1, 99)
xr_assert_allclose(res, expected)
# Case that would fail if we didn't transpose
out = a[0].where(a > 1, -1)
res = xr_function([a], out)(a_test)
expected = a_test[0].where(a_test > 1, -1)
xr_assert_allclose(res, expected)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论