提交 4413fb47 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Support duplicate indices in AdvancedIncSubtensor1's Numba implementation

上级 8e48b629
...@@ -499,7 +499,6 @@ def numba_funcify_Subtensor(op, node, **kwargs): ...@@ -499,7 +499,6 @@ def numba_funcify_Subtensor(op, node, **kwargs):
@numba_funcify.register(IncSubtensor) @numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor) @numba_funcify.register(AdvancedIncSubtensor)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_IncSubtensor(op, node, **kwargs): def numba_funcify_IncSubtensor(op, node, **kwargs):
incsubtensor_def_src = create_index_func( incsubtensor_def_src = create_index_func(
...@@ -515,6 +514,39 @@ def numba_funcify_IncSubtensor(op, node, **kwargs): ...@@ -515,6 +514,39 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
return numba_njit(incsubtensor_fn) return numba_njit(incsubtensor_fn)
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
if set_instead_of_inc:
@numba_njit
def advancedincsubtensor1_inplace(x, vals, idxs):
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else:
@numba_njit
def advancedincsubtensor1_inplace(x, vals, idxs):
for idx, val in zip(idxs, vals):
x[idx] += val
return x
if inplace:
return advancedincsubtensor1_inplace
else:
@numba_njit
def advancedincsubtensor1(x, vals, idxs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
return advancedincsubtensor1
@numba_funcify.register(DeepCopyOp) @numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs): def numba_funcify_DeepCopyOp(op, node, **kwargs):
......
...@@ -550,6 +550,11 @@ def test_IncSubtensor(x, y, indices): ...@@ -550,6 +550,11 @@ def test_IncSubtensor(x, y, indices):
at.as_tensor(rng.poisson(size=(2, 4, 5))), at.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2],), ([1, 2],),
), ),
(
at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
at.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 1],),
),
], ],
) )
def test_AdvancedIncSubtensor1(x, y, indices): def test_AdvancedIncSubtensor1(x, y, indices):
...@@ -583,6 +588,14 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -583,6 +588,14 @@ def test_AdvancedIncSubtensor1(x, y, indices):
at.as_tensor(rng.poisson(size=(2, 4))), at.as_tensor(rng.poisson(size=(2, 4))),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]),
), ),
pytest.param(
at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
at.as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]),
marks=pytest.mark.xfail(
reason="Duplicate index handling hasn't been implemented, yet."
),
),
], ],
) )
def test_AdvancedIncSubtensor(x, y, indices): def test_AdvancedIncSubtensor(x, y, indices):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论