提交 d8501d14 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba AdvancedIndexing: Complete support for integer (and mixed basic) advanced indexing

When default `ignore_updates=True` for inc_subtensor, and boolean indices were rewritten during specialize
上级 fe10f960
import operator
import sys
from hashlib import sha256
from textwrap import dedent, indent
import numba
import numpy as np
......@@ -14,13 +15,13 @@ from pytensor.link.numba.cache import (
compile_numba_function_src,
)
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
generate_fallback_impl,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -29,7 +30,7 @@ from pytensor.tensor.subtensor import (
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
from pytensor.tensor.type_other import MakeSlice, NoneTypeT
def slice_new(self, start, stop, step):
......@@ -243,14 +244,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
else:
_x, _y, *idxs = node.inputs
basic_idxs = [
idx
for idx in idxs
if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [
{
"axis": i,
......@@ -262,248 +255,401 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(idx.type, TensorType)
]
# Special implementation for consecutive integer vector indices
if (
not basic_idxs
and len(adv_idxs) >= 2
# Must be integer vectors
# Todo: we could allow shape=(1,) if this is the shape of x
and all(
(adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
for adv_idx in adv_idxs
must_ignore_duplicates = (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and op.ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
and not all(
adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" for adv_idx in adv_idxs
)
# Must be consecutive
and not op.non_consecutive_adv_indexing(node)
)
# Special implementation for integer indices that respects duplicates
if (
not must_ignore_duplicates
and len(adv_idxs) >= 1
and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs)
# Implementation does not support newaxis
and not any(isinstance(idx.type, NoneTypeT) for idx in idxs)
):
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
return vector_integer_advanced_indexing(op, node, **kwargs)
must_respect_duplicates = (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and not op.ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
and not all(
adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" for adv_idx in adv_idxs
)
)
# Other cases not natively supported by Numba (fallback to obj-mode)
# Cases natively supported by Numba
if (
# Numba indexing, like Numpy, ignores duplicates in update
not must_respect_duplicates
# Numba does not support indexes with more than one dimension
any(idx["ndim"] > 1 for idx in adv_idxs)
and not any(idx["ndim"] > 1 for idx in adv_idxs)
# Nor multiple vector indexes
or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
# The default PyTensor implementation does not handle duplicate indices correctly
or (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and not (
op.ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
or all(
adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool"
for adv_idx in adv_idxs
)
)
)
and not sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
):
return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key(
op, func="fallback_impl"
)
return numba_funcify_default_subtensor(op, node, **kwargs)
# What's left should all be supported natively by numba
return numba_funcify_default_subtensor(op, node, **kwargs)
# Otherwise fallback to obj_mode
return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key(
op, func="fallback_impl"
)
def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False
@register_funcify_and_cache_key(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
    """Numba dispatch for ``AdvancedIncSubtensor1``.

    Delegates to ``vector_integer_advanced_indexing``, forwarding ``op``,
    ``node`` and any extra keyword arguments, and returns its result unchanged.
    """
    return vector_integer_advanced_indexing(op, node=node, **kwargs)
def numba_funcify_multiple_integer_vector_indexing(
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
def vector_integer_advanced_indexing(
op: AdvancedSubtensor1 | AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
if isinstance(op, AdvancedSubtensor):
idxs = node.inputs[1:]
else:
idxs = node.inputs[2:]
"""Implement all forms of advanced indexing (and assignment) that combine basic and vector integer indices.
first_axis = next(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
)
try:
after_last_axis = next(
i
for i, idx in enumerate(idxs[first_axis:], start=first_axis)
if not isinstance(idx.type, TensorType)
)
except StopIteration:
after_last_axis = len(idxs)
last_axis = after_last_axis - 1
It does not support `newaxis` in basic indices
vector_indices = idxs[first_axis:after_last_axis]
assert all(v.type.broadcastable == (False,) for v in vector_indices)
y_is_broadcasted = False
It handles += like `np.add.at` would, accumulating add for duplicate indices.
if isinstance(op, AdvancedSubtensor):
Examples
--------
Codegen for an AdvancedSubtensor, with non-consecutive matrix indices, and a slice(1, None) basic index
@numba_basic.numba_njit
def advanced_subtensor_multiple_vector(x, *idxs):
none_slices = idxs[:first_axis]
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
idx_shape = vec_idxs[0].shape
shape_bef = x_shape[:first_axis]
shape_aft = x_shape[after_last_axis:]
out_shape = (*shape_bef, *idx_shape, *shape_aft)
out_buffer = np.empty(out_shape, dtype=x.dtype)
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
.. code-block:: python
# AdvancedSubtensor [id A] <Tensor3(int64, shape=(2, 2, 3))>
# ├─ <Tensor3(int64, shape=(3, 4, 5))> [id B] <Tensor3(int64, shape=(3, 4, 5))>
# ├─ [[1 2] [2 1]] [id C] <Matrix(uint8, shape=(2, 2))>
# ├─ SliceConstant{1, None, None} [id D] <slice>
# └─ [[0 0] [0 0]] [id E] <Matrix(uint8, shape=(2, 2))>
def advanced_integer_vector_indexing(x, idx0, idx1, idx2):
# Move advanced indexed dims to the front (if needed)
x_adv_dims_front = x.transpose((0, 2, 1))
# Perform basic indexing once (if needed)
basic_indexed_x = x_adv_dims_front[:, :, idx1]
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape)
(idx0, idx2) = (
np.broadcast_to(idx0, adv_idx_shape),
np.broadcast_to(idx2, adv_idx_shape),
)
# Create output buffer
adv_idx_size = idx0.size
basic_idx_shape = basic_indexed_x.shape[2:]
out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype)
# Index over tuples of raveled advanced indices and write to output buffer
for i, scalar_idxs in enumerate(zip(idx0.ravel(), idx2.ravel())):
out_buffer[i] = basic_indexed_x[scalar_idxs]
# Unravel out_buffer (if needed)
out_buffer = out_buffer.reshape((*adv_idx_shape, *basic_idx_shape))
# Move advanced output indexing group to its final position (if needed) and return
return out_buffer
ret_func = advanced_subtensor_multiple_vector
else:
inplace = op.inplace
# Check if y must be broadcasted
# Includes the last integer vector index,
x, y = node.inputs[:2]
indexed_bcast_dims = (
*x.type.broadcastable[:first_axis],
*x.type.broadcastable[last_axis:],
)
y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)
Codegen for similar AdvancedSetSubtensor
if op.set_instead_of_inc:
.. code-block:: python
@numba_basic.numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
AdvancedSetSubtensor [id A] <Tensor3(int64, shape=(3, 4, 5))>
├─ x [id B] <Tensor3(int64, shape=(3, 4, 5))>
├─ y [id C] <Matrix(int64, shape=(2, 4))>
├─ [1 2] [id D] <Vector(uint8, shape=(2,))>
├─ SliceConstant{None, None, None} [id E] <slice>
└─ [3 4] [id F] <Vector(uint8, shape=(2,))>
if inplace:
out = x
else:
out = x.copy()
def set_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2):
# Expand dims of y explicitly (if needed)
y = y
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
# Copy x (if not inplace)
x = x.copy()
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out
# Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = x.transpose((0, 2, 1))
ret_func = advanced_set_subtensor_multiple_vector
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = x_adv_dims_front[:, :, idx1]
else:
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape)
(idx0, idx2) = (np.broadcast_to(idx0, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape))
@numba_basic.numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = y
if inplace:
out = x
else:
out = x.copy()
# Broadcast y to the shape of each assignment/update
adv_idx_shape = idx0.shape
basic_idx_shape = basic_indexed_x.shape[2:]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip(idx0, idx2)):
basic_indexed_x[scalar_idxs] = y_bcast[i]
ret_func = advanced_inc_subtensor_multiple_vector
# Return the original x, with the entries updated
return x
cache_key = subtensor_op_cache_key(
op,
func="multiple_integer_vector_indexing",
y_is_broadcasted=y_is_broadcasted,
first_axis=first_axis,
last_axis=last_axis,
)
return ret_func, cache_key
Codegen for an AdvancedIncSubtensor, with two contiguous advanced groups not in the leading axis
@register_funcify_and_cache_key(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
x, vals, _idxs = node.inputs
broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
# TODO: Add runtime_broadcast check
if set_instead_of_inc:
if broadcast_with_index:
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] = core_val
return x
.. code-block:: python
else:
AdvancedIncSubtensor [id A] <Tensor3(int64, shape=(3, 4, 5))>
├─ x [id B] <Tensor3(int64, shape=(3, 4, 5))>
├─ y [id C] <Matrix(int64, shape=(2, 2))>
├─ SliceConstant{1, None, None} [id D] <slice>
├─ [1 2] [id E] <Vector(uint8, shape=(2,))>
└─ [3 4] [id F] <Vector(uint8, shape=(2,))>
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
# no strict argument because incompatible with numba
for idx, val in zip(idxs, vals):
x[idx] = val
return x
def inc_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2):
# Expand dims of y explicitly (if needed)
y = y
# Copy x (if not inplace)
x = x.copy()
# Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = x.transpose((1, 2, 0))
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = x_adv_dims_front[:, :, idx0]
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx1.shape, idx2.shape)
(idx1, idx2) = (np.broadcast_to(idx1, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape))
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = y.transpose((1, 0))
# Broadcast y to the shape of each assignment/update
adv_idx_shape = idx1.shape
basic_idx_shape = basic_indexed_x.shape[2:]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip(idx1, idx2)):
basic_indexed_x[scalar_idxs] += y_bcast[i]
# Return the original x, with the entries updated
return x
"""
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
x, *idxs = node.inputs
else:
if broadcast_with_index:
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] += core_val
return x
x, y, *idxs = node.inputs
[out] = node.outputs
else:
adv_indices_pos = tuple(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
)
assert adv_indices_pos # Otherwise it's just basic indexing
basic_indices_pos = tuple(
i for i, idx in enumerate(idxs) if not isinstance(idx.type, TensorType)
)
explicit_basic_indices_pos = (*basic_indices_pos, *range(len(idxs), x.type.ndim))
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
# no strict argument because unsupported by numba
# TODO: this doesn't come up in tests
for idx, val in zip(idxs, vals):
x[idx] += val
return x
# Create index signature and split them among basic and advanced
idx_signature = ", ".join(f"idx{i}" for i in range(len(idxs)))
adv_indices = [f"idx{i}" for i in adv_indices_pos]
basic_indices = [f"idx{i}" for i in basic_indices_pos]
cache_key = subtensor_op_cache_key(
op,
func="numba_funcify_advancedincsubtensor1",
broadcast_with_index=broadcast_with_index,
# Define transpose axis so that advanced indexing dims are on the front
adv_axis_front_order = (*adv_indices_pos, *explicit_basic_indices_pos)
adv_axis_front_transpose_needed = adv_axis_front_order != tuple(range(x.ndim))
adv_idx_ndim = max(idxs[i].ndim for i in adv_indices_pos)
# Helper needed for basic indexing after moving advanced indices to the front
basic_indices_with_none_slices = ", ".join(
(*((":",) * len(adv_indices)), *basic_indices)
)
if inplace:
return advancedincsubtensor1_inplace, cache_key
# Position of the first advanced index dimension after indexing the array
if (np.diff(adv_indices_pos) > 1).any():
# If not consecutive, it's always at the front
out_adv_axis_pos = 0
else:
# Otherwise wherever the first advanced index is located
out_adv_axis_pos = adv_indices_pos[0]
to_tuple = create_tuple_string # alias to make code more readable below
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
# Define transpose axis on the output to restore original meaning
# After (potentially) having transposed advanced indexing dims to the front unlike numpy
_final_axis_order = list(range(adv_idx_ndim, out.type.ndim))
for i in range(adv_idx_ndim):
_final_axis_order.insert(out_adv_axis_pos + i, i)
final_axis_order = tuple(_final_axis_order)
del _final_axis_order
final_axis_transpose_needed = final_axis_order != tuple(range(out.type.ndim))
func_name = "advanced_integer_vector_indexing"
codegen = dedent(
f"""
def {func_name}(x, {idx_signature}):
# Move advanced indexed dims to the front (if needed)
x_adv_dims_front = {f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"}
# Perform basic indexing once (if needed)
basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"}
"""
)
if len(adv_indices) > 1:
codegen += indent(
dedent(
f"""
# Broadcast indices
adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])}
{to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])}
"""
),
" " * 4,
)
codegen += indent(
dedent(
f"""
# Create output buffer
adv_idx_size = {adv_indices[0]}.size
basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:]
out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype)
# Index over tuples of raveled advanced indices and write to output buffer
for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 else adv_indices)}):
out_buffer[i] = basic_indexed_x[scalar_idxs]
# Unravel out_buffer (if needed)
out_buffer = {f"out_buffer.reshape((*{adv_indices[0]}.shape, *basic_idx_shape))" if adv_idx_ndim != 1 else "out_buffer"}
# Move advanced output indexing group to its final position (if needed) and return
return {f"out_buffer.transpose({final_axis_order})" if final_axis_transpose_needed else "out_buffer"}
"""
),
" " * 4,
)
else:
# Make implicit dims of y explicit to simplify code
# Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis
indexed_ndim = x[tuple(idxs)].type.ndim
y_expand_dims = [":"] * y.type.ndim
y_implicit_dims = range(indexed_ndim - y.type.ndim)
for axis in y_implicit_dims:
y_expand_dims.insert(axis, "None")
# We transpose the advanced dimensions of x to the front for indexing
# We may have to do the same for y
# Note that if there are non-contiguous advanced indices,
# y must already be aligned with the indices jumping to the front
y_adv_axis_front_order = tuple(
range(
# Position of the first advanced axis after indexing
out_adv_axis_pos,
# Position of the last advanced axis after indexing
out_adv_axis_pos + adv_idx_ndim,
)
)
y_order = tuple(range(indexed_ndim))
y_adv_axis_front_order = (
*y_adv_axis_front_order,
# Basic indices, after explicit_expand_dims
*(o for o in y_order if o not in y_adv_axis_front_order),
)
y_adv_axis_front_transpose_needed = y_adv_axis_front_order != y_order
func_name = f"{'set' if op.set_instead_of_inc else 'inc'}_advanced_integer_vector_indexing"
codegen = dedent(
f"""
def {func_name}(x, y, {idx_signature}):
# Expand dims of y explicitly (if needed)
y = {f"y[{', '.join(y_expand_dims)},]" if y_implicit_dims else "y"}
# Copy x (if not inplace)
x = {"x" if op.inplace else "x.copy()"}
# Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = {f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"}
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"}
"""
)
if len(adv_indices) > 1:
codegen += indent(
dedent(
f"""
# Broadcast indices
adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])}
{to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])}
"""
),
" " * 4,
)
codegen += indent(
dedent(
f"""
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = {f"y.transpose({y_adv_axis_front_order})" if y_adv_axis_front_transpose_needed else "y"}
# Broadcast y to the shape of each assignment/update
adv_idx_shape = {adv_indices[0]}.shape
basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = {"y_bcast.ravel().reshape((-1, *basic_idx_shape))" if adv_idx_ndim != 1 else "y_bcast"}
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 else adv_indices)}):
basic_indexed_x[scalar_idxs] {"=" if op.set_instead_of_inc else "+="} y_bcast[i]
# Return the original x, with the entries updated
return x
"""
),
" " * 4,
)
@numba_basic.numba_njit
def advancedincsubtensor1(x, vals, idxs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
cache_key = subtensor_op_cache_key(
op,
codegen=codegen,
)
return advancedincsubtensor1, cache_key
ret_func = numba_basic.numba_njit(
compile_numba_function_src(
codegen,
function_name=func_name,
global_env=globals(),
cache_key=cache_key,
)
)
return ret_func, cache_key
......@@ -83,7 +83,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node):
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
def bool_idx_to_nonzero(fgraph, node):
"""Convert boolean indexing into equivalent vector boolean index, supported by our dispatch
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
"""
if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs
else:
x, y, *idxs = node.inputs
if any(
(
(isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indexes or np.newaxis
return None
bool_idxs = [
(i, idx)
bool_pos = {
i
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
]
if len(bool_idxs) != 1:
# Get out if there are no or multiple boolean idxs
return None
}
[(bool_idx_pos, bool_idx)] = bool_idxs
bool_idx_ndim = bool_idx.type.ndim
if bool_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
if not bool_pos:
return None
x_shape = x.shape
raveled_x = x.reshape(
(*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
)
raveled_bool_idx = bool_idx.ravel()
new_idxs = list(idxs)
new_idxs[bool_idx_pos] = raveled_bool_idx
new_idxs = []
for i, idx in enumerate(idxs):
if i in bool_pos:
new_idxs.extend(idx.nonzero())
else:
new_idxs.append(idx)
if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs)
new_out = node.op(x, *new_idxs)
else:
# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to match the original shape
new_out = new_out.reshape(x_shape)
new_out = node.op(x, y, *new_idxs)
return [copy_stack_trace(node.outputs[0], new_out)]
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers
x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
It also handles multiple integer indices, but only if they don't broadcast
x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
"""
op = node.op
non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
if is_inc_subtensor:
x, y, *idxs = node.inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if non_consecutive_adv_indexing or (
y.type.broadcastable != x[tuple(idxs)].type.broadcastable
):
return None
else:
x, *idxs = node.inputs
if any(
(
(isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indices or np.newaxis
return None
int_idxs_and_pos = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]
if not int_idxs_and_pos:
return None
int_idxs_pos, int_idxs = zip(
*int_idxs_and_pos, strict=False
) # strict=False because by definition it's true
first_int_idx_pos = int_idxs_pos[0]
first_int_idx = int_idxs[0]
first_int_idx_bcast = first_int_idx.type.broadcastable
if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return None
int_idxs_ndim = len(first_int_idx_bcast)
if (
int_idxs_ndim == 0
): # This should be a basic indexing operation, rewrite elsewhere
return None
int_idxs_need_raveling = int_idxs_ndim > 1
if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return None
# Reorder non-consecutive indices
if non_consecutive_adv_indexing:
assert not is_inc_subtensor # Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition = list(int_idxs_pos) + [
i for i in range(len(idxs)) if i not in int_idxs_pos
]
idxs = tuple(idxs[a] for a in transposition)
x = x.transpose(transposition)
first_int_idx_pos = 0
del int_idxs_pos # Make sure they are not wrongly used
# Ravel multidimensional indices
if int_idxs_need_raveling:
idxs = list(idxs)
for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
idxs[idx_pos] = int_idx.ravel()
# Index with reordered and/or raveled indices
new_subtensor = x[tuple(idxs)]
if is_inc_subtensor:
y_shape = tuple(y.shape)
y_raveled_shape = (
*y_shape[:first_int_idx_pos],
-1,
*y_shape[first_int_idx_pos + int_idxs_ndim :],
)
y_raveled = y.reshape(y_raveled_shape)
new_out = inc_subtensor(
new_subtensor,
y_raveled,
set_instead_of_inc=op.set_instead_of_inc,
ignore_duplicates=op.ignore_duplicates,
inplace=op.inplace,
)
else:
# Unravel advanced indexing dimensions
raveled_shape = tuple(new_subtensor.shape)
unraveled_shape = (
*raveled_shape[:first_int_idx_pos],
*first_int_idx.shape,
*raveled_shape[first_int_idx_pos + 1 :],
)
new_out = new_subtensor.reshape(unraveled_shape)
return [copy_stack_trace(node.outputs[0], new_out)]
optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx,
bool_idx_to_nonzero.__name__,
bool_idx_to_nonzero,
"numba",
"shape_unsafe", # It can mask invalid mask sizes
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
......
......@@ -109,117 +109,95 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize(
"x, indices, objmode_needed",
"x, indices",
[
# Single vector indexing (supported natively by Numba)
# Single vector indexing
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(0, [1, 2, 2, 3]),
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(np.array([True, False, False])),
False,
),
# Single multidimensional indexing (supported after specialization rewrites)
# Single multidimensional indexing
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(bool)),
False,
),
# Multiple vector indexing (supported by our dispatcher)
# Multiple vector indexing
(
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], [2, 3]),
False,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]),
False,
),
(
as_tensor(np.arange(3 * 5 * 7).reshape((3, 5, 7))),
([1, 2], [3, 4], [5, 6]),
False,
),
# Non-consecutive vector indexing, supported by our dispatcher after rewriting
# Non-consecutive vector indexing
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
False,
),
# Multiple multidimensional integer indexing (supported by our dispatcher)
# Multiple multidimensional integer indexing
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [[0, 0], [0, 0]]),
False,
),
(
as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))),
(slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]),
False,
),
# Multiple multidimensional indexing with broadcasting, only supported in obj mode
# Multiple multidimensional indexing with broadcasting
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]),
True,
),
# multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode
# multiple multidimensional integer indexing mixed with basic indexing
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
True,
),
],
)
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
def test_AdvancedSubtensor(x, indices, objmode_needed):
def test_AdvancedSubtensor(x, indices):
"""Test NumPy's advanced indexing in more than one dimension."""
x_pt = x.type()
out_pt = x_pt[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor)
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSubtensor's perform method",
)
if objmode_needed
else contextlib.nullcontext()
):
compare_numba_and_py(
[x_pt],
[out_pt],
[x.data],
numba_mode=numba_mode.including("specialize"),
)
compare_numba_and_py(
[x_pt],
[out_pt],
[x.data],
# Specialize allows running boolean indexing without falling back to object mode
# Thanks to bool_idx_to_nonzero rewrite
numba_mode=numba_mode.including("specialize"),
)
@pytest.mark.parametrize(
......@@ -323,7 +301,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
@pytest.mark.parametrize(
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
"x, y, indices, duplicate_indices, duplicate_indices_require_obj_mode",
[
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -331,7 +309,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector index
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -343,7 +320,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
), # Mixed basic and broadcasted vector idx
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -351,7 +327,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector idx
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -359,7 +334,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True,
False,
True,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -367,21 +341,11 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True,
False,
True,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])), # Broadcasted boolean index
False, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])), # Broadcasted boolean index
True, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
False,
False,
),
......@@ -391,7 +355,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(np.eye(3).astype(bool)), # Boolean index
False,
False,
False,
),
(
np.arange(3 * 3 * 5).reshape((3, 3, 5)),
......@@ -402,7 +365,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
), # Boolean index, mixed with basic index
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -410,7 +372,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([1, 2], [2, 3]), # 2 vector indices
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -418,7 +379,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None), [1, 2], [2, 3]), # 2 vector indices
False,
False,
False,
),
(
np.arange(3 * 4 * 6).reshape((3, 4, 6)),
......@@ -426,7 +386,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([1, 2], [2, 3], [4, 5]), # 3 vector indices
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -434,15 +393,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([1, 2], [2, 3]), # 2 vector indices
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 4)),
([1, 2], slice(None), [3, 4]), # Non-consecutive vector indices
False,
True,
True,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -453,8 +410,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
[3, 4],
), # Mixed double vector index and basic index
False,
True,
True,
False,
),
(
np.arange(5),
......@@ -462,7 +418,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([[1, 2], [2, 3]]), # matrix index
False,
False,
False,
),
(
np.arange(3 * 5).reshape((3, 5)),
......@@ -470,23 +425,20 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(1, 3), [[1, 2], [2, 3]]), # matrix index, mixed with basic index
False,
False,
False,
),
(
np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)), # Same as before, but Y broadcasts
(slice(1, 3), [[1, 2], [2, 3]]),
False,
True,
True,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 5)),
([1, 1], [2, 2]), # Repeated indices
True,
False,
False,
True,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......@@ -494,7 +446,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]), # 2 matrix indices
False,
False,
False,
),
],
)
......@@ -505,8 +456,7 @@ def test_AdvancedIncSubtensor(
y,
indices,
duplicate_indices,
set_requires_objmode,
inc_requires_objmode,
duplicate_indices_require_obj_mode,
inplace,
):
# Need rewrite to support certain forms of advanced indexing without object mode
......@@ -518,17 +468,9 @@ def test_AdvancedIncSubtensor(
out_pt = set_subtensor(x_pt[indices], y_pt, inplace=inplace)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
)
if set_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace:
# Test updates inplace
......@@ -536,23 +478,58 @@ def test_AdvancedIncSubtensor(
fn(x, y + 1)
assert not np.all(x == x_orig)
out_pt = inc_subtensor(
x_pt[indices], y_pt, ignore_duplicates=not duplicate_indices, inplace=inplace
)
out_pt = inc_subtensor(x_pt[indices], y_pt, inplace=inplace)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
)
if inc_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace:
# Test updates inplace
x_orig = x.copy()
fn(x, y)
assert not np.all(x == x_orig)
if duplicate_indices:
# If inc_subtensor is called with `ignore_duplicates=True` and it is not one of the cases supported by Numba,
# we have to fall back to obj_mode
out_pt = inc_subtensor(
x_pt[indices], y_pt, inplace=inplace, ignore_duplicates=True
)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
)
if duplicate_indices_require_obj_mode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace:
# Test updates inplace
x_orig = x.copy()
fn(x, y)
assert not np.all(x == x_orig)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
    """Check that advanced indexing mixed with ``None`` warns about object mode.

    Two graphs are built from a matrix indexed with ``None`` followed by two
    integer lists: the plain read and its ``.inc(5)`` counterpart. For each,
    ``compare_numba_and_py`` must run under a ``UserWarning`` whose message
    matches the corresponding "Numba will use object mode" pattern.
    """
    # Expected to be resolved by https://github.com/pymc-devs/pytensor/issues/1564,
    # after which these cases can become parametrizations of the tests above.
    x = pt.matrix("x")
    newaxis_adv_idx = (None, [0, 1, 2], [0, 1, 2])

    # Read through the newaxis + advanced index.
    indexed = x[newaxis_adv_idx]
    with pytest.warns(
        UserWarning,
        match=r"Numba will use object mode to run AdvancedSubtensor's perform method",
    ):
        test_value = np.random.normal(size=(4, 4))
        compare_numba_and_py([x], [indexed], [test_value])

    # Increment through the same index.
    incremented = x[newaxis_adv_idx].inc(5)
    with pytest.warns(
        UserWarning,
        match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
    ):
        test_value = np.random.normal(size=(4, 4))
        compare_numba_and_py([x], [incremented], [test_value])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论