提交 4298b761 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Split: Validate sizes

上级 de6aca85
......@@ -124,10 +124,19 @@ def numba_funcify_Join(op, **kwargs):
@register_funcify_default_op_cache_key(Split)
def numba_funcify_Split(op, **kwargs):
@numba_basic.numba_njit
def split(tensor, axis, indices):
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
def split(x, axis, sizes):
if (sizes < 0).any():
raise ValueError("Split sizes cannot be negative")
axis = axis.item()
split_indices = np.cumsum(sizes)
if split_indices[-1] != x.shape[axis]:
raise ValueError(
f"Split sizes sum to {split_indices[-1]}; expected {x.shape[axis]}"
)
return np.split(x, split_indices[:-1], axis=axis)
return split
cache_version = 1
return split, cache_version
@register_funcify_default_op_cache_key(ExtractDiag)
......
......@@ -10,6 +10,7 @@ from pytensor.scalar import Add
from tests.link.numba.test_basic import (
compare_numba_and_py,
compare_shape_dtype,
numba_mode,
)
from tests.tensor.test_basic import check_alloc_runtime_broadcast
......@@ -245,6 +246,18 @@ def test_Split_view():
)
def test_split_errors():
x = pt.dvector("x", shape=(5,))
splits = pt.tensor(shape=(3,), dtype="int64")
outs = pt.split(x, splits)
fn = function([x, splits], outs, mode=numba_mode)
test_x = np.zeros((5,))
with pytest.raises(ValueError, match="Split sizes sum to 4; expected 5"):
fn(test_x, np.array([1, 2, 1], dtype="int64"))
with pytest.raises(ValueError, match="Split sizes cannot be negative"):
fn(test_x, np.array([2, 4, -1], dtype="int64"))
@pytest.mark.parametrize(
"val, offset",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论