提交 6a19a378 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Support more cases of multi-dimensional advanced indexing and updating in Numba

Extends pre-existing rewrite to ravel multiple integer indices, and to place them consecutively. The following cases should now be supported without object mode: * Advanced integer indexing (not mixed with basic or boolean indexing) that do not require broadcasting of indices * Consecutive advanced integer indexing updating (set/inc) (not mixed with basic or boolean indexing) that do not require broadcasting of indices or y.
上级 b721b669
...@@ -150,7 +150,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -150,7 +150,7 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
for adv_idx in adv_idxs for adv_idx in adv_idxs
) )
# Must be consecutive # Must be consecutive
and not op.non_contiguous_adv_indexing(node) and not op.non_consecutive_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted # y in set/inc_subtensor cannot be broadcasted
and ( and (
y is None y is None
......
...@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node): ...@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
return [copy_stack_trace(node.outputs[0], new_out)] return [copy_stack_trace(node.outputs[0], new_out)]
@node_rewriter(tracks=[AdvancedSubtensor]) @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_int_idx(fgraph, node): def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba """Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers
x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
It also handles multiple integer indices, but only if they don't broadcast
x[eye(3), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
""" """
x, *idxs = node.inputs op = node.op
non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
if is_inc_subtensor:
x, y, *idxs = node.inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if non_consecutive_adv_indexing or (
y.type.broadcastable != x[tuple(idxs)].type.broadcastable
):
return None
else:
x, *idxs = node.inputs
if any( if any(
( (
...@@ -2049,39 +2072,90 @@ def ravel_multidimensional_int_idx(fgraph, node): ...@@ -2049,39 +2072,90 @@ def ravel_multidimensional_int_idx(fgraph, node):
) )
for idx in idxs for idx in idxs
): ):
# Get out if there are any other advanced indexes or np.newaxis # Get out if there are any other advanced indices or np.newaxis
return None return None
int_idxs = [ int_idxs_and_pos = [
(i, idx) (i, idx)
for i, idx in enumerate(idxs) for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes) if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
] ]
if len(int_idxs) != 1: if not int_idxs_and_pos:
# Get out if there are no or multiple integer idxs
return None return None
[(int_idx_pos, int_idx)] = int_idxs int_idxs_pos, int_idxs = zip(
if int_idx.type.ndim < 2: *int_idxs_and_pos, strict=False
# No need to do anything if it's a vector or scalar, as it's already supported by Numba ) # strict=False because by definition it's true
first_int_idx_pos = int_idxs_pos[0]
first_int_idx = int_idxs[0]
first_int_idx_bcast = first_int_idx.type.broadcastable
if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return None return None
raveled_int_idx = int_idx.ravel() int_idxs_ndim = len(first_int_idx_bcast)
new_idxs = list(idxs) if (
new_idxs[int_idx_pos] = raveled_int_idx int_idxs_ndim == 0
raveled_subtensor = x[tuple(new_idxs)] ): # This should be a basic indexing operation, rewrite elsewhere
return None
# Reshape into correct shape
# Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing int_idxs_need_raveling = int_idxs_ndim > 1
# must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
raveled_shape = raveled_subtensor.shape # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
unraveled_shape = ( return None
*raveled_shape[:int_idx_pos],
*int_idx.shape, # Reorder non-consecutive indices
*raveled_shape[int_idx_pos + 1 :], if non_consecutive_adv_indexing:
) assert not is_inc_subtensor # Sanity check that we got out if this was the case
new_out = raveled_subtensor.reshape(unraveled_shape) # This case works as if all the advanced indices were on the front
transposition = list(int_idxs_pos) + [
i for i in range(len(idxs)) if i not in int_idxs_pos
]
idxs = tuple(idxs[a] for a in transposition)
x = x.transpose(transposition)
first_int_idx_pos = 0
del int_idxs_pos # Make sure they are not wrongly used
# Ravel multidimensional indices
if int_idxs_need_raveling:
idxs = list(idxs)
for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
idxs[idx_pos] = int_idx.ravel()
# Index with reordered and/or raveled indices
new_subtensor = x[tuple(idxs)]
if is_inc_subtensor:
y_shape = tuple(y.shape)
y_raveled_shape = (
*y_shape[:first_int_idx_pos],
-1,
*y_shape[first_int_idx_pos + int_idxs_ndim :],
)
y_raveled = y.reshape(y_raveled_shape)
new_out = inc_subtensor(
new_subtensor,
y_raveled,
set_instead_of_inc=op.set_instead_of_inc,
ignore_duplicates=op.ignore_duplicates,
inplace=op.inplace,
)
else:
# Unravel advanced indexing dimensions
raveled_shape = tuple(new_subtensor.shape)
unraveled_shape = (
*raveled_shape[:first_int_idx_pos],
*first_int_idx.shape,
*raveled_shape[first_int_idx_pos + 1 :],
)
new_out = new_subtensor.reshape(unraveled_shape)
return [copy_stack_trace(node.outputs[0], new_out)] return [copy_stack_trace(node.outputs[0], new_out)]
...@@ -2089,10 +2163,12 @@ optdb["specialize"].register( ...@@ -2089,10 +2163,12 @@ optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__, ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx, ravel_multidimensional_bool_idx,
"numba", "numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
) )
optdb["specialize"].register( optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__, ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx, ravel_multidimensional_int_idx,
"numba", "numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
) )
import logging import logging
import sys import sys
import warnings
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import chain, groupby from itertools import chain, groupby
from textwrap import dedent from textwrap import dedent
...@@ -580,8 +581,8 @@ def group_indices(indices): ...@@ -580,8 +581,8 @@ def group_indices(indices):
return idx_groups return idx_groups
def _non_contiguous_adv_indexing(indices) -> bool: def _non_consecutive_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-contiguous (i.e., split by basic indexing).""" """Check if the advanced indexing is non-consecutive (i.e., split by basic indexing)."""
idx_groups = group_indices(indices) idx_groups = group_indices(indices)
# This means that there are at least two groups of advanced indexing separated by basic indexing # This means that there are at least two groups of advanced indexing separated by basic indexing
return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0]) return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
...@@ -611,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -611,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape)) remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
idx_groups = group_indices(indices) idx_groups = group_indices(indices)
if _non_contiguous_adv_indexing(indices): if _non_consecutive_adv_indexing(indices):
# In this case NumPy places the advanced index groups in the front of the array # In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing # https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups = sorted(idx_groups, key=lambda x: x[0]) idx_groups = sorted(idx_groups, key=lambda x: x[0])
...@@ -2796,10 +2797,17 @@ class AdvancedSubtensor(Op): ...@@ -2796,10 +2797,17 @@ class AdvancedSubtensor(Op):
@staticmethod @staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool: def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedSubtensor.non_consecutive_adv_indexing(node)
@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
""" """
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing). Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-contiguous, This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position. output array, regardless of their original position.
...@@ -2814,10 +2822,10 @@ class AdvancedSubtensor(Op): ...@@ -2814,10 +2822,10 @@ class AdvancedSubtensor(Op):
Returns Returns
------- -------
bool bool
True if the advanced indexing is non-contiguous, False otherwise. True if the advanced indexing is non-consecutive, False otherwise.
""" """
_, *idxs = node.inputs _, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs) return _non_consecutive_adv_indexing(idxs)
advanced_subtensor = AdvancedSubtensor() advanced_subtensor = AdvancedSubtensor()
...@@ -2835,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): ...@@ -2835,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
if isinstance(batch_idx, TensorVariable) if isinstance(batch_idx, TensorVariable)
) )
if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)): if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
# Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
# which would put the indexed results to the left of the batch dimensions! # which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex # TODO: Not all cases must be handled by Blockwise, but the logic is complex
...@@ -2844,7 +2852,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): ...@@ -2844,7 +2852,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
# TODO: Implement these internally, so Blockwise is always a safe fallback # TODO: Implement these internally, so Blockwise is always a safe fallback
if any(not isinstance(idx, TensorVariable) for idx in idxs): if any(not isinstance(idx, TensorVariable) for idx in idxs):
raise NotImplementedError( raise NotImplementedError(
"Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing " "Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
"and slices or newaxis is currently not supported." "and slices or newaxis is currently not supported."
) )
else: else:
...@@ -2954,10 +2962,17 @@ class AdvancedIncSubtensor(Op): ...@@ -2954,10 +2962,17 @@ class AdvancedIncSubtensor(Op):
@staticmethod @staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool: def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)
@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
""" """
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing). Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-contiguous, This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position. output array, regardless of their original position.
...@@ -2972,10 +2987,10 @@ class AdvancedIncSubtensor(Op): ...@@ -2972,10 +2987,10 @@ class AdvancedIncSubtensor(Op):
Returns Returns
------- -------
bool bool
True if the advanced indexing is non-contiguous, False otherwise. True if the advanced indexing is non-consecutive, False otherwise.
""" """
_, _, *idxs = node.inputs _, _, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs) return _non_consecutive_adv_indexing(idxs)
advanced_inc_subtensor = AdvancedIncSubtensor() advanced_inc_subtensor = AdvancedIncSubtensor()
......
...@@ -81,11 +81,6 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -81,11 +81,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
(np.array([True, False, False])), (np.array([True, False, False])),
False, False,
), ),
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], [2, 3]),
False,
),
# Single multidimensional indexing (supported after specialization rewrites) # Single multidimensional indexing (supported after specialization rewrites)
( (
as_tensor(np.arange(3 * 3).reshape((3, 3))), as_tensor(np.arange(3 * 3).reshape((3, 3))),
...@@ -117,6 +112,12 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -117,6 +112,12 @@ def test_AdvancedSubtensor1_out_of_bounds():
(slice(2, None), np.eye(3).astype(bool)), (slice(2, None), np.eye(3).astype(bool)),
False, False,
), ),
# Multiple vector indexing (supported by our dispatcher)
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], [2, 3]),
False,
),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]), (slice(None), [1, 2], [3, 4]),
...@@ -127,18 +128,35 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -127,18 +128,35 @@ def test_AdvancedSubtensor1_out_of_bounds():
([1, 2], [3, 4], [5, 6]), ([1, 2], [3, 4], [5, 6]),
False, False,
), ),
# Non-contiguous vector indexing, only supported in obj mode # Non-consecutive vector indexing, supported by our dispatcher after rewriting
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]),
True, False,
),
# Multiple multidimensional integer indexing (supported by our dispatcher)
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [[0, 0], [0, 0]]),
False,
),
(
as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))),
(slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]),
False,
), ),
# >1d vector indexing, only supported in obj mode # Multiple multidimensional indexing with broadcasting, only supported in obj mode
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]), ([[1, 2], [2, 1]], [0, 0]),
True, True,
), ),
# multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
True,
),
], ],
) )
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed @pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
...@@ -297,7 +315,7 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -297,7 +315,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(4 * 5).reshape(4, 5), -np.arange(4 * 5).reshape(4, 5),
(0, [1, 2, 2, 3]), # Broadcasted vector index (0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True, True,
False, False,
True, True,
...@@ -305,7 +323,7 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -305,7 +323,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
np.array([-99]), # Broadcasted value np.array([-99]), # Broadcasted value
(0, [1, 2, 2, 3]), # Broadcasted vector index (0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True, True,
False, False,
True, True,
...@@ -380,7 +398,7 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -380,7 +398,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 4)), rng.poisson(size=(2, 4)),
([1, 2], slice(None), [3, 4]), # Non-contiguous vector indices ([1, 2], slice(None), [3, 4]), # Non-consecutive vector indices
False, False,
True, True,
True, True,
...@@ -400,15 +418,23 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -400,15 +418,23 @@ def test_AdvancedIncSubtensor1(x, y, indices):
( (
np.arange(5), np.arange(5),
rng.poisson(size=(2, 2)), rng.poisson(size=(2, 2)),
([[1, 2], [2, 3]]), # matrix indices ([[1, 2], [2, 3]]), # matrix index
False,
False,
False,
),
(
np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(2, 2, 2)),
(slice(1, 3), [[1, 2], [2, 3]]), # matrix index, mixed with basic index
False,
False,
False, False,
False, # Gets converted to AdvancedIncSubtensor1
True, # This is actually supported with the default `ignore_duplicates=False`
), ),
( (
np.arange(3 * 5).reshape((3, 5)), np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)), rng.poisson(size=(1, 2, 2)), # Same as before, but Y broadcasts
(slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index (slice(1, 3), [[1, 2], [2, 3]]),
False, False,
True, True,
True, True,
...@@ -421,6 +447,14 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -421,6 +447,14 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False, False,
False, False,
), ),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(3, 2, 2)),
(slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]), # 2 matrix indices
False,
False,
False,
),
], ],
) )
@pytest.mark.parametrize("inplace", (False, True)) @pytest.mark.parametrize("inplace", (False, True))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论