提交 afc1a6ca authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix get_canonical_form_slice when lengths are numpy integers

Introduced in f9dfe702
上级 781073b6
......@@ -325,7 +325,7 @@ def get_canonical_form_slice(
and is_step_constant
and is_length_constant
):
assert isinstance(length, int)
assert isinstance(length, int | np.integer)
_start, _stop, _step = slice(start, stop, step).indices(length)
if _start <= _stop and _step >= 1:
return slice(_start, _stop, _step), 1
......
......@@ -154,8 +154,11 @@ class TestGetCanonicalFormSlice:
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1
def test_all_integer(self):
res = get_canonical_form_slice(slice(1, 5, 2), 7)
@pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar])
def test_all_integer(self, int_fn):
res = get_canonical_form_slice(
slice(int_fn(1), int_fn(5), int_fn(2)), int_fn(7)
)
assert isinstance(res[0], slice)
assert res[1] == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论