提交 74af76f9 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add boundschecks in numba backend

上级 7caee0e8
...@@ -554,7 +554,7 @@ def numba_funcify_Subtensor(op, node, **kwargs): ...@@ -554,7 +554,7 @@ def numba_funcify_Subtensor(op, node, **kwargs):
subtensor_def_src, "subtensor", {**globals(), **global_env} subtensor_def_src, "subtensor", {**globals(), **global_env}
) )
return numba_njit(subtensor_fn) return numba_njit(subtensor_fn, boundscheck=True)
@numba_funcify.register(IncSubtensor) @numba_funcify.register(IncSubtensor)
...@@ -570,7 +570,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs): ...@@ -570,7 +570,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
) )
return numba_njit(incsubtensor_fn) return numba_njit(incsubtensor_fn, boundscheck=True)
@numba_funcify.register(AdvancedIncSubtensor1) @numba_funcify.register(AdvancedIncSubtensor1)
...@@ -580,7 +580,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -580,7 +580,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if set_instead_of_inc: if set_instead_of_inc:
@numba_njit @numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs): def advancedincsubtensor1_inplace(x, vals, idxs):
for idx, val in zip(idxs, vals): for idx, val in zip(idxs, vals):
x[idx] = val x[idx] = val
...@@ -588,7 +588,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -588,7 +588,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else: else:
@numba_njit @numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs): def advancedincsubtensor1_inplace(x, vals, idxs):
for idx, val in zip(idxs, vals): for idx, val in zip(idxs, vals):
x[idx] += val x[idx] += val
......
...@@ -373,6 +373,14 @@ def test_AdvancedSubtensor1(x, indices): ...@@ -373,6 +373,14 @@ def test_AdvancedSubtensor1(x, indices):
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
def test_AdvancedSubtensor1_out_of_bounds():
out_at = at_subtensor.advanced_subtensor1(np.arange(3), [4])
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError):
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论