提交 5f374dbd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba AdvancedIncSubtensor1 with broadcasted values

上级 1e96b894
......@@ -604,36 +604,70 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
return numba_njit(incsubtensor_fn, boundscheck=True)
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_set(x, vals, idxs):
for idx, val in zip(idxs, vals):
x[idx] = val
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
x, vals, idxs = node.inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
if set_instead_of_inc:
if broadcast:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] = core_val
return x
else:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_inc(x, vals, idxs):
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] += val
x[idx] = val
return x
else:
if broadcast:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] += core_val
return x
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
if set_instead_of_inc:
advancedincsubtensor1_inplace = global_numba_func(
advancedincsubtensor1_inplace_set
)
else:
advancedincsubtensor1_inplace = global_numba_func(
advancedincsubtensor1_inplace_inc
)
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] += val
return x
if inplace:
return global_numba_func(advancedincsubtensor1_inplace)
return advancedincsubtensor1_inplace
else:
@numba_njit
......
......@@ -406,6 +406,7 @@ def test_Subtensor(x, indices):
"x, indices",
[
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
],
)
def test_AdvancedSubtensor1(x, indices):
......@@ -498,6 +499,27 @@ def test_IncSubtensor(x, y, indices):
pt.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 1],),
),
# Broadcasting values
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(1, 4, 5))),
([0, 2, 0],),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(5,))),
([0, 2],),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=())),
([2, 0],),
),
(
pt.as_tensor(np.arange(5)),
pt.as_tensor(rng.poisson(size=())),
([2, 0],),
),
],
)
def test_AdvancedIncSubtensor1(x, y, indices):
......@@ -511,11 +533,21 @@ def test_AdvancedIncSubtensor1(x, y, indices):
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
# With symbolic inputs
x_pt = x.type()
out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y, *indices)
y_pt = y.type()
out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data])
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
out_pt = pt_subtensor.AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices
)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
@pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论