提交 d8501d14 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba AdvancedIndexing: Complete support for integer (and mixed basic) advanced indexing

When default `ignore_updates=True` for inc_subtensor, and boolean indices were rewritten during specialize
上级 fe10f960
......@@ -83,7 +83,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node):
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
def bool_idx_to_nonzero(fgraph, node):
"""Convert boolean indexing into equivalent vector boolean index, supported by our dispatch
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
"""
if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs
else:
x, y, *idxs = node.inputs
if any(
(
(isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indexes or np.newaxis
return None
bool_idxs = [
(i, idx)
bool_pos = {
i
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
]
if len(bool_idxs) != 1:
# Get out if there are no or multiple boolean idxs
return None
}
[(bool_idx_pos, bool_idx)] = bool_idxs
bool_idx_ndim = bool_idx.type.ndim
if bool_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
if not bool_pos:
return None
x_shape = x.shape
raveled_x = x.reshape(
(*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
)
raveled_bool_idx = bool_idx.ravel()
new_idxs = list(idxs)
new_idxs[bool_idx_pos] = raveled_bool_idx
if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs)
new_idxs = []
for i, idx in enumerate(idxs):
if i in bool_pos:
new_idxs.extend(idx.nonzero())
else:
# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = new_out.reshape(x_shape)
return [copy_stack_trace(node.outputs[0], new_out)]
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers
x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
It also handles multiple integer indices, but only if they don't broadcast
x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
"""
op = node.op
non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
if is_inc_subtensor:
x, y, *idxs = node.inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if non_consecutive_adv_indexing or (
y.type.broadcastable != x[tuple(idxs)].type.broadcastable
):
return None
else:
x, *idxs = node.inputs
if any(
(
(isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indices or np.newaxis
return None
int_idxs_and_pos = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]
if not int_idxs_and_pos:
return None
int_idxs_pos, int_idxs = zip(
*int_idxs_and_pos, strict=False
) # strict=False because by definition it's true
first_int_idx_pos = int_idxs_pos[0]
first_int_idx = int_idxs[0]
first_int_idx_bcast = first_int_idx.type.broadcastable
if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return None
int_idxs_ndim = len(first_int_idx_bcast)
if (
int_idxs_ndim == 0
): # This should be a basic indexing operation, rewrite elsewhere
return None
int_idxs_need_raveling = int_idxs_ndim > 1
if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return None
# Reorder non-consecutive indices
if non_consecutive_adv_indexing:
assert not is_inc_subtensor # Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition = list(int_idxs_pos) + [
i for i in range(len(idxs)) if i not in int_idxs_pos
]
idxs = tuple(idxs[a] for a in transposition)
x = x.transpose(transposition)
first_int_idx_pos = 0
del int_idxs_pos # Make sure they are not wrongly used
# Ravel multidimensional indices
if int_idxs_need_raveling:
idxs = list(idxs)
for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
idxs[idx_pos] = int_idx.ravel()
# Index with reordered and/or raveled indices
new_subtensor = x[tuple(idxs)]
if is_inc_subtensor:
y_shape = tuple(y.shape)
y_raveled_shape = (
*y_shape[:first_int_idx_pos],
-1,
*y_shape[first_int_idx_pos + int_idxs_ndim :],
)
y_raveled = y.reshape(y_raveled_shape)
new_out = inc_subtensor(
new_subtensor,
y_raveled,
set_instead_of_inc=op.set_instead_of_inc,
ignore_duplicates=op.ignore_duplicates,
inplace=op.inplace,
)
new_idxs.append(idx)
if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(x, *new_idxs)
else:
# Unravel advanced indexing dimensions
raveled_shape = tuple(new_subtensor.shape)
unraveled_shape = (
*raveled_shape[:first_int_idx_pos],
*first_int_idx.shape,
*raveled_shape[first_int_idx_pos + 1 :],
)
new_out = new_subtensor.reshape(unraveled_shape)
new_out = node.op(x, y, *new_idxs)
return [copy_stack_trace(node.outputs[0], new_out)]
optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx,
bool_idx_to_nonzero.__name__,
bool_idx_to_nonzero,
"numba",
"shape_unsafe", # It can mask invalid mask sizes
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论