提交 9bee61d5 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix empty xtensor indexing

上级 18449731
...@@ -153,8 +153,18 @@ class Index(XOp): ...@@ -153,8 +153,18 @@ class Index(XOp):
if len(idxs) > x_ndim: if len(idxs) > x_ndim:
raise IndexError("Too many indices") raise IndexError("Too many indices")
# Remove useless trailing slice(None) indices
# starting at -1 ensures we handle the case of no useful indices correctly with idxs[:0]
last_useful_index = -1
for i, idx in enumerate(idxs):
if isinstance(idx, slice) and idx == slice(None):
continue
last_useful_index = i
# Convert (useful) indices to symbolic variables
idxs = [ idxs = [
as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False) as_idx_variable(idx, dim)
for idx, dim in zip(idxs[: last_useful_index + 1], x_dims, strict=False)
] ]
for i, idx in enumerate(idxs): for i, idx in enumerate(idxs):
...@@ -174,7 +184,9 @@ class Index(XOp): ...@@ -174,7 +184,9 @@ class Index(XOp):
idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)] idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)]
combine_dim_info(idx_dim, idx_dim_shape) combine_dim_info(idx_dim, idx_dim_shape)
for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]): for dim_i, shape_i in zip(
x_dims[last_useful_index + 1 :], x_shape[last_useful_index + 1 :]
):
# Add back any unindexed dimensions # Add back any unindexed dimensions
if dim_i not in out_dims: if dim_i not in out_dims:
# If the dimension was not indexed, we keep it as is # If the dimension was not indexed, we keep it as is
......
...@@ -488,9 +488,8 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -488,9 +488,8 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
) )
indexers = indexers_kwargs indexers = indexers_kwargs
if not indexers: elif indexers is None:
# No-op indexers = {}
return self
if missing_dims not in {"raise", "warn", "ignore"}: if missing_dims not in {"raise", "warn", "ignore"}:
raise ValueError( raise ValueError(
......
...@@ -12,6 +12,7 @@ from xarray import DataArray ...@@ -12,6 +12,7 @@ from xarray import DataArray
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.xtensor import xtensor from pytensor.xtensor import xtensor
from tests.unittest_tools import assert_equal_computations
from tests.xtensor.util import ( from tests.xtensor.util import (
xr_arange_like, xr_arange_like,
xr_assert_allclose, xr_assert_allclose,
...@@ -511,3 +512,33 @@ def test_diff(dim, n): ...@@ -511,3 +512,33 @@ def test_diff(dim, n):
else: else:
expected_res = x_test.diff(dim, n=n) expected_res = x_test.diff(dim, n=n)
xr_assert_allclose(res, expected_res) xr_assert_allclose(res, expected_res)
def test_empty_index():
x = xtensor("x", shape=(5, 5), dims=("a", "b"))
out1 = x[()]
out2 = x[...]
out3 = x.isel({})
out4 = x.isel({"c": 0}, missing_dims="ignore")
assert_equal_computations([out1], [out2])
assert_equal_computations([out1], [out3])
assert_equal_computations([out1], [out4])
fn = xr_function([x], out1)
x_test = xr_random_like(x)
xr_assert_allclose(fn(x_test), x_test)
def test_empty_update_index():
x = xtensor("x", shape=(5, 5), dims=("a", "b"))
out1 = x[()].inc(1)
out2 = x[...].inc(1)
out3 = x.isel({}).inc(1)
out4 = x.isel({"c": 0}, missing_dims="ignore").inc(1)
assert_equal_computations([out1], [out2])
assert_equal_computations([out1], [out3])
assert_equal_computations([out1], [out4])
fn = xr_function([x], out1)
x_test = xr_random_like(x)
xr_assert_allclose(fn(x_test), x_test + 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论