提交 6a19a378 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Support more cases of multi-dimensional advanced indexing and updating in Numba

Extends pre-existing rewrite to ravel multiple integer indices, and to place them consecutively. The following cases should now be supported without object mode: * Advanced integer indexing (not mixed with basic or boolean indexing) that do not require broadcasting of indices * Consecutive advanced integer indexing updating (set/inc) (not mixed with basic or boolean indexing) that do not require broadcasting of indices or y.
上级 b721b669
......@@ -150,7 +150,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
for adv_idx in adv_idxs
)
# Must be consecutive
and not op.non_contiguous_adv_indexing(node)
and not op.non_consecutive_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
......
......@@ -2029,17 +2029,40 @@ def ravel_multidimensional_bool_idx(fgraph, node):
return [copy_stack_trace(node.outputs[0], new_out)]
@node_rewriter(tracks=[AdvancedSubtensor])
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers
x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
It also handles multiple integer indices, but only if they don't broadcast
x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
"""
op = node.op
non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
if is_inc_subtensor:
x, y, *idxs = node.inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if non_consecutive_adv_indexing or (
y.type.broadcastable != x[tuple(idxs)].type.broadcastable
):
return None
else:
x, *idxs = node.inputs
if any(
......@@ -2049,39 +2072,90 @@ def ravel_multidimensional_int_idx(fgraph, node):
)
for idx in idxs
):
# Get out if there are any other advanced indexes or np.newaxis
# Get out if there are any other advanced indices or np.newaxis
return None
int_idxs = [
int_idxs_and_pos = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]
if len(int_idxs) != 1:
# Get out if there are no or multiple integer idxs
if not int_idxs_and_pos:
return None
[(int_idx_pos, int_idx)] = int_idxs
if int_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
int_idxs_pos, int_idxs = zip(
*int_idxs_and_pos, strict=False
) # strict=False because by definition it's true
first_int_idx_pos = int_idxs_pos[0]
first_int_idx = int_idxs[0]
first_int_idx_bcast = first_int_idx.type.broadcastable
if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return None
raveled_int_idx = int_idx.ravel()
new_idxs = list(idxs)
new_idxs[int_idx_pos] = raveled_int_idx
raveled_subtensor = x[tuple(new_idxs)]
int_idxs_ndim = len(first_int_idx_bcast)
if (
int_idxs_ndim == 0
): # This should be a basic indexing operation, rewrite elsewhere
return None
int_idxs_need_raveling = int_idxs_ndim > 1
if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return None
# Reshape into correct shape
# Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
# must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
raveled_shape = raveled_subtensor.shape
# Reorder non-consecutive indices
if non_consecutive_adv_indexing:
assert not is_inc_subtensor # Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition = list(int_idxs_pos) + [
i for i in range(len(idxs)) if i not in int_idxs_pos
]
idxs = tuple(idxs[a] for a in transposition)
x = x.transpose(transposition)
first_int_idx_pos = 0
del int_idxs_pos # Make sure they are not wrongly used
# Ravel multidimensional indices
if int_idxs_need_raveling:
idxs = list(idxs)
for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
idxs[idx_pos] = int_idx.ravel()
# Index with reordered and/or raveled indices
new_subtensor = x[tuple(idxs)]
if is_inc_subtensor:
y_shape = tuple(y.shape)
y_raveled_shape = (
*y_shape[:first_int_idx_pos],
-1,
*y_shape[first_int_idx_pos + int_idxs_ndim :],
)
y_raveled = y.reshape(y_raveled_shape)
new_out = inc_subtensor(
new_subtensor,
y_raveled,
set_instead_of_inc=op.set_instead_of_inc,
ignore_duplicates=op.ignore_duplicates,
inplace=op.inplace,
)
else:
# Unravel advanced indexing dimensions
raveled_shape = tuple(new_subtensor.shape)
unraveled_shape = (
*raveled_shape[:int_idx_pos],
*int_idx.shape,
*raveled_shape[int_idx_pos + 1 :],
*raveled_shape[:first_int_idx_pos],
*first_int_idx.shape,
*raveled_shape[first_int_idx_pos + 1 :],
)
new_out = raveled_subtensor.reshape(unraveled_shape)
new_out = new_subtensor.reshape(unraveled_shape)
return [copy_stack_trace(node.outputs[0], new_out)]
......@@ -2089,10 +2163,12 @@ optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
import logging
import sys
import warnings
from collections.abc import Callable, Iterable
from itertools import chain, groupby
from textwrap import dedent
......@@ -580,8 +581,8 @@ def group_indices(indices):
return idx_groups
def _non_contiguous_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
def _non_consecutive_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-consecutive (i.e., split by basic indexing)."""
idx_groups = group_indices(indices)
# This means that there are at least two groups of advanced indexing separated by basic indexing
return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
......@@ -611,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
idx_groups = group_indices(indices)
if _non_contiguous_adv_indexing(indices):
if _non_consecutive_adv_indexing(indices):
# In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups = sorted(idx_groups, key=lambda x: x[0])
......@@ -2796,10 +2797,17 @@ class AdvancedSubtensor(Op):
@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedSubtensor.non_consecutive_adv_indexing(node)
@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-contiguous,
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their opriginal position.
......@@ -2814,10 +2822,10 @@ class AdvancedSubtensor(Op):
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)
return _non_consecutive_adv_indexing(idxs)
advanced_subtensor = AdvancedSubtensor()
......@@ -2835,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
if isinstance(batch_idx, TensorVariable)
)
if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
# Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
# which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex
......@@ -2844,7 +2852,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
# TODO: Implement these internally, so Blockwise is always a safe fallback
if any(not isinstance(idx, TensorVariable) for idx in idxs):
raise NotImplementedError(
"Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing "
"Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
"and slices or newaxis is currently not supported."
)
else:
......@@ -2954,10 +2962,17 @@ class AdvancedIncSubtensor(Op):
@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)
@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-contiguous,
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their opriginal position.
......@@ -2972,10 +2987,10 @@ class AdvancedIncSubtensor(Op):
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, _, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)
return _non_consecutive_adv_indexing(idxs)
advanced_inc_subtensor = AdvancedIncSubtensor()
......
......@@ -81,11 +81,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
(np.array([True, False, False])),
False,
),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], [2, 3]),
False,
),
# Single multidimensional indexing (supported after specialization rewrites)
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
......@@ -117,6 +112,12 @@ def test_AdvancedSubtensor1_out_of_bounds():
(slice(2, None), np.eye(3).astype(bool)),
False,
),
# Multiple vector indexing (supported by our dispatcher)
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], [2, 3]),
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]),
......@@ -127,18 +128,35 @@ def test_AdvancedSubtensor1_out_of_bounds():
([1, 2], [3, 4], [5, 6]),
False,
),
# Non-contiguous vector indexing, only supported in obj mode
# Non-consecutive vector indexing, supported by our dispatcher after rewriting
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
True,
False,
),
# Multiple multidimensional integer indexing (supported by our dispatcher)
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [[0, 0], [0, 0]]),
False,
),
(
as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))),
(slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]),
False,
),
# >1d vector indexing, only supported in obj mode
# Multiple multidimensional indexing with broadcasting, only supported in obj mode
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]),
True,
),
# multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
True,
),
],
)
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
......@@ -297,7 +315,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(4 * 5).reshape(4, 5),
(0, [1, 2, 2, 3]), # Broadcasted vector index
(0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True,
False,
True,
......@@ -305,7 +323,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
np.array([-99]), # Broadcasted value
(0, [1, 2, 2, 3]), # Broadcasted vector index
(0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True,
False,
True,
......@@ -380,7 +398,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 4)),
([1, 2], slice(None), [3, 4]), # Non-contiguous vector indices
([1, 2], slice(None), [3, 4]), # Non-consecutive vector indices
False,
True,
True,
......@@ -400,15 +418,23 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
np.arange(5),
rng.poisson(size=(2, 2)),
([[1, 2], [2, 3]]), # matrix indices
([[1, 2], [2, 3]]), # matrix index
False,
False,
False,
),
(
np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(2, 2, 2)),
(slice(1, 3), [[1, 2], [2, 3]]), # matrix index, mixed with basic index
False,
False,
False,
False, # Gets converted to AdvancedIncSubtensor1
True, # This is actually supported with the default `ignore_duplicates=False`
),
(
np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)),
(slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index
rng.poisson(size=(1, 2, 2)), # Same as before, but Y broadcasts
(slice(1, 3), [[1, 2], [2, 3]]),
False,
True,
True,
......@@ -421,6 +447,14 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(3, 2, 2)),
(slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]), # 2 matrix indices
False,
False,
False,
),
],
)
@pytest.mark.parametrize("inplace", (False, True))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论