提交 4413fb47 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Support duplicate indices in AdvancedIncSubtensor1's Numba implementation

上级 8e48b629
......@@ -499,7 +499,6 @@ def numba_funcify_Subtensor(op, node, **kwargs):
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_IncSubtensor(op, node, **kwargs):
incsubtensor_def_src = create_index_func(
......@@ -515,6 +514,39 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
return numba_njit(incsubtensor_fn)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
if set_instead_of_inc:
@numba_njit
def advancedincsubtensor1_inplace(x, vals, idxs):
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else:
@numba_njit
def advancedincsubtensor1_inplace(x, vals, idxs):
for idx, val in zip(idxs, vals):
x[idx] += val
return x
if inplace:
return advancedincsubtensor1_inplace
else:
@numba_njit
def advancedincsubtensor1(x, vals, idxs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
return advancedincsubtensor1
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
......
......@@ -550,6 +550,11 @@ def test_IncSubtensor(x, y, indices):
at.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2],),
),
(
at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
at.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 1],),
),
],
)
def test_AdvancedIncSubtensor1(x, y, indices):
......@@ -583,6 +588,14 @@ def test_AdvancedIncSubtensor1(x, y, indices):
at.as_tensor(rng.poisson(size=(2, 4))),
([1, 2], slice(None), [3, 4]),
),
pytest.param(
at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
at.as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]),
marks=pytest.mark.xfail(
reason="Duplicate index handling hasn't been implemented, yet."
),
),
],
)
def test_AdvancedIncSubtensor(x, y, indices):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论