提交 01ac0c7c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba RavelMultiIndex: Fix scalars with clip mode

上级 30d6a746
...@@ -154,7 +154,11 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs): ...@@ -154,7 +154,11 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
stacked_indices[..., i] %= dim_limit stacked_indices[..., i] %= dim_limit
elif mode == "clip": elif mode == "clip":
dim_indices = stacked_indices[..., i] dim_indices = stacked_indices[..., i]
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1) # Cannot call np.clip on scalars
if vec_indices:
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
else:
stacked_indices[..., i] = max(0, min(dim_indices, dim_limit - 1))
else: # raise else: # raise
dim_indices = stacked_indices[..., i] dim_indices = stacked_indices[..., i]
invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i]) invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])
......
...@@ -171,6 +171,12 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -171,6 +171,12 @@ def test_FillDiagonalOffset(a, val, offset):
"raise", "raise",
ValueError, ValueError,
), ),
(
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
(pt.lvector(), np.array([2, 3, 4])),
"wrap",
None,
),
( (
tuple( tuple(
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
...@@ -188,6 +194,12 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -188,6 +194,12 @@ def test_FillDiagonalOffset(a, val, offset):
"wrap", "wrap",
None, None,
), ),
(
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
(pt.lvector(), np.array([2, 3, 4])),
"clip",
None,
),
( (
tuple( tuple(
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论