提交 5f374dbd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba AdvancedIncSubtensor1 with broadcasted values

上级 1e96b894
...@@ -604,36 +604,70 @@ def numba_funcify_IncSubtensor(op, node, **kwargs): ...@@ -604,36 +604,70 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
return numba_njit(incsubtensor_fn, boundscheck=True) return numba_njit(incsubtensor_fn, boundscheck=True)
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_set(x, vals, idxs):
    """Overwrite ``x[idx]`` with the paired entry of ``vals``, in place.

    Entries of ``idxs`` and ``vals`` are consumed pairwise, in order, so when
    an index is repeated the last paired value is the one left in ``x``.

    Parameters
    ----------
    x
        Indexable container that is mutated in place.
    vals
        Values to write; must have exactly one entry per index.
    idxs
        Indices into the first axis of ``x``.

    Returns
    -------
    The same ``x`` object that was passed in, after mutation.

    Raises
    ------
    ValueError
        If ``idxs`` and ``vals`` do not have the same length.
    """
    # ``zip`` stops at the shorter input, so without this check a length
    # mismatch (e.g. values that were meant to be broadcast over the indices)
    # would silently skip some of the updates instead of failing.
    if len(idxs) != len(vals):
        raise ValueError("The number of indices and values must match.")
    for idx, val in zip(idxs, vals):
        x[idx] = val
    return x
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_inc(x, vals, idxs):
    """Add the paired entry of ``vals`` to ``x[idx]``, in place.

    Entries of ``idxs`` and ``vals`` are consumed pairwise, in order, so when
    an index is repeated every paired value is accumulated into ``x``.

    Parameters
    ----------
    x
        Indexable container that is mutated in place.
    vals
        Increments to apply; must have exactly one entry per index.
    idxs
        Indices into the first axis of ``x``.

    Returns
    -------
    The same ``x`` object that was passed in, after mutation.

    Raises
    ------
    ValueError
        If ``idxs`` and ``vals`` do not have the same length.
    """
    # ``zip`` stops at the shorter input, so without this check a length
    # mismatch (e.g. values that were meant to be broadcast over the indices)
    # would silently skip some of the increments instead of failing.
    if len(idxs) != len(vals):
        raise ValueError("The number of indices and values must match.")
    for idx, val in zip(idxs, vals):
        x[idx] += val
    return x
@numba_funcify.register(AdvancedIncSubtensor1) @numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc set_instead_of_inc = op.set_instead_of_inc
x, vals, idxs = node.inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
if set_instead_of_inc: if set_instead_of_inc:
advancedincsubtensor1_inplace = global_numba_func( if broadcast:
advancedincsubtensor1_inplace_set
) @numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] = core_val
return x
else:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else: else:
advancedincsubtensor1_inplace = global_numba_func( if broadcast:
advancedincsubtensor1_inplace_inc
) @numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] += core_val
return x
else:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
x[idx] += val
return x
if inplace: if inplace:
return global_numba_func(advancedincsubtensor1_inplace) return advancedincsubtensor1_inplace
else: else:
@numba_njit @numba_njit
......
...@@ -406,6 +406,7 @@ def test_Subtensor(x, indices): ...@@ -406,6 +406,7 @@ def test_Subtensor(x, indices):
"x, indices", "x, indices",
[ [
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
], ],
) )
def test_AdvancedSubtensor1(x, indices): def test_AdvancedSubtensor1(x, indices):
...@@ -498,6 +499,27 @@ def test_IncSubtensor(x, y, indices): ...@@ -498,6 +499,27 @@ def test_IncSubtensor(x, y, indices):
pt.as_tensor(rng.poisson(size=(2, 4, 5))), pt.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 1],), ([1, 1],),
), ),
# Broadcasting values
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(1, 4, 5))),
([0, 2, 0],),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=(5,))),
([0, 2],),
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
pt.as_tensor(rng.poisson(size=())),
([2, 0],),
),
(
pt.as_tensor(np.arange(5)),
pt.as_tensor(rng.poisson(size=())),
([2, 0],),
),
], ],
) )
def test_AdvancedIncSubtensor1(x, y, indices): def test_AdvancedIncSubtensor1(x, y, indices):
...@@ -511,11 +533,21 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -511,11 +533,21 @@ def test_AdvancedIncSubtensor1(x, y, indices):
out_fg = FunctionGraph([], [out_pt]) out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
# With symbolic inputs
x_pt = x.type() x_pt = x.type()
out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y, *indices) y_pt = y.type()
out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt]) out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data]) compare_numba_and_py(out_fg, [x.data, y.data])
out_pt = pt_subtensor.AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices
)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论