Unverified 提交 0824dba8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Support multidimensional boolean set/inc_subtensor in Numba via rewrite (#1108)

* Remove opinionated message about ignore_duplicates. Setting to false can lead to slower code on C/Numba backend which don't support np.add.at natively. * Setting to false can lead to slower code on C/Numba backend which don't support np.add.at natively. Support multidimensional boolean set/inc_subtensor in Numba via rewrite
上级 9dad122f
...@@ -249,7 +249,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): ...@@ -249,7 +249,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
This is only done when there's a single vector index. This is only done when there's a single vector index.
""" """
if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates: if node.op.ignore_duplicates:
# `AdvancedIncSubtensor1` does not ignore duplicate index values # `AdvancedIncSubtensor1` does not ignore duplicate index values
return return
...@@ -1967,19 +1967,26 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): ...@@ -1967,19 +1967,26 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
return new_out return new_out
@node_rewriter(tracks=[AdvancedSubtensor]) @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node): def ravel_multidimensional_bool_idx(fgraph, node):
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()] x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
""" """
x, *idxs = node.inputs if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs
else:
x, y, *idxs = node.inputs
if any( if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int") (
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int"))
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs for idx in idxs
): ):
# Get out if there are any other advanced indexes # Get out if there are any other advanced indexes or np.newaxis
return None return None
bool_idxs = [ bool_idxs = [
...@@ -2007,7 +2014,16 @@ def ravel_multidimensional_bool_idx(fgraph, node): ...@@ -2007,7 +2014,16 @@ def ravel_multidimensional_bool_idx(fgraph, node):
new_idxs = list(idxs) new_idxs = list(idxs)
new_idxs[bool_idx_pos] = raveled_bool_idx new_idxs[bool_idx_pos] = raveled_bool_idx
return [raveled_x[tuple(new_idxs)]] if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs)
else:
# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = new_out.reshape(x_shape)
return [copy_stack_trace(node.outputs[0], new_out)]
@node_rewriter(tracks=[AdvancedSubtensor]) @node_rewriter(tracks=[AdvancedSubtensor])
...@@ -2024,10 +2040,13 @@ def ravel_multidimensional_int_idx(fgraph, node): ...@@ -2024,10 +2040,13 @@ def ravel_multidimensional_int_idx(fgraph, node):
x, *idxs = node.inputs x, *idxs = node.inputs
if any( if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool") (
(isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs for idx in idxs
): ):
# Get out if there are any other advanced indexes # Get out if there are any other advanced indexes or np.newaxis
return None return None
int_idxs = [ int_idxs = [
...@@ -2059,7 +2078,8 @@ def ravel_multidimensional_int_idx(fgraph, node): ...@@ -2059,7 +2078,8 @@ def ravel_multidimensional_int_idx(fgraph, node):
*int_idx.shape, *int_idx.shape,
*raveled_shape[int_idx_pos + 1 :], *raveled_shape[int_idx_pos + 1 :],
) )
return [raveled_subtensor.reshape(unraveled_shape)] new_out = raveled_subtensor.reshape(unraveled_shape)
return [copy_stack_trace(node.outputs[0], new_out)]
optdb["specialize"].register( optdb["specialize"].register(
......
...@@ -1456,11 +1456,8 @@ def inc_subtensor( ...@@ -1456,11 +1456,8 @@ def inc_subtensor(
views; if they overlap, the result of this `Op` will generally be views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``. incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates ignore_duplicates
This determines whether or not ``x[indices] += y`` is used or This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``. When the special duplicates handling of ``np.add.at(x, indices, y)``.
``np.add.at`` isn't required, setting this option to ``True``
(i.e. using ``x[indices] += y``) can resulting in faster compiled
graphs.
Examples Examples
-------- --------
......
...@@ -334,8 +334,19 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -334,8 +334,19 @@ def test_AdvancedIncSubtensor1(x, y, indices):
-np.arange(3), -np.arange(3),
(np.eye(3).astype(bool)), # Boolean index (np.eye(3).astype(bool)), # Boolean index
False, False,
True, False,
True, False,
),
(
np.arange(3 * 3 * 5).reshape((3, 3, 5)),
rng.poisson(size=(3, 2)),
(
np.eye(3).astype(bool),
slice(-2, None),
), # Boolean index, mixed with basic index
False,
False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -394,10 +405,18 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -394,10 +405,18 @@ def test_AdvancedIncSubtensor1(x, y, indices):
rng.poisson(size=(2, 2)), rng.poisson(size=(2, 2)),
([[1, 2], [2, 3]]), # matrix indices ([[1, 2], [2, 3]]), # matrix indices
False, False,
False, # Gets converted to AdvancedIncSubtensor1
True, # This is actually supported with the default `ignore_duplicates=False`
),
(
np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)),
(slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index
False,
True, True,
True, True,
), ),
pytest.param( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 5)), rng.poisson(size=(2, 5)),
([1, 1], [2, 2]), # Repeated indices ([1, 1], [2, 2]), # Repeated indices
...@@ -418,6 +437,9 @@ def test_AdvancedIncSubtensor( ...@@ -418,6 +437,9 @@ def test_AdvancedIncSubtensor(
inc_requires_objmode, inc_requires_objmode,
inplace, inplace,
): ):
# Need rewrite to support certain forms of advanced indexing without object mode
mode = numba_mode.including("specialize")
x_pt = pt.as_tensor(x).type("x") x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y") y_pt = pt.as_tensor(y).type("y")
...@@ -432,7 +454,7 @@ def test_AdvancedIncSubtensor( ...@@ -432,7 +454,7 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode if set_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y]) fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
...@@ -452,7 +474,7 @@ def test_AdvancedIncSubtensor( ...@@ -452,7 +474,7 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode if inc_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y]) fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
x_orig = x.copy() x_orig = x.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论