提交 01ac0c7c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba RavelMultiIndex: Fix scalars with clip mode

上级 30d6a746
......@@ -154,7 +154,11 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
stacked_indices[..., i] %= dim_limit
elif mode == "clip":
dim_indices = stacked_indices[..., i]
# Cannot call np.clip on scalars
if vec_indices:
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
else:
stacked_indices[..., i] = max(0, min(dim_indices, dim_limit - 1))
else: # raise
dim_indices = stacked_indices[..., i]
invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])
......
......@@ -171,6 +171,12 @@ def test_FillDiagonalOffset(a, val, offset):
"raise",
ValueError,
),
(
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
(pt.lvector(), np.array([2, 3, 4])),
"wrap",
None,
),
(
tuple(
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
......@@ -188,6 +194,12 @@ def test_FillDiagonalOffset(a, val, offset):
"wrap",
None,
),
(
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
(pt.lvector(), np.array([2, 3, 4])),
"clip",
None,
),
(
tuple(
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论