提交 e854b4d2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix isel `missing_dims="ignore"`

上级 9e46e6b4
...@@ -505,7 +505,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -505,7 +505,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
raise TypeError("Ellipsis (...) is an invalid labeled index") raise TypeError("Ellipsis (...) is an invalid labeled index")
try: try:
indices[dims.index(key)] = idx indices[dims.index(key)] = idx
except IndexError: except ValueError:
if missing_dims == "raise": if missing_dims == "raise":
raise ValueError( raise ValueError(
f"Dimension {key} does not exist. Expected one of {dims}" f"Dimension {key} does not exist. Expected one of {dims}"
......
import re
import pytest import pytest
...@@ -148,3 +150,21 @@ def test_minimum_compile(): ...@@ -148,3 +150,21 @@ def test_minimum_compile():
minimum_mode = Mode(linker="py", optimizer="minimum_compile") minimum_mode = Mode(linker="py", optimizer="minimum_compile")
result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode) result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode)
np.testing.assert_array_equal(result, np.ones((3, 2))) np.testing.assert_array_equal(result, np.ones((3, 2)))
def test_isel_missing_dims():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
# Check valid case works
assert x.isel(b=0).dims == ("a",)
with pytest.raises(ValueError):
x.isel(c=0)
with pytest.warns(
UserWarning,
match=re.escape("Dimension c does not exist. Expected one of ('a', 'b')"),
):
x.isel(c=0, missing_dims="warn")
x.isel(c=0, missing_dims="ignore").dims == ("a", "b")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论