提交 d8501d14 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba AdvancedIndexing: Complete support for integer (and mixed basic) advanced indexing

When default `ignore_updates=True` for inc_subtensor, and boolean indices were rewritten during specialize
上级 fe10f960
import operator import operator
import sys import sys
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent, indent
import numba import numba
import numpy as np import numpy as np
...@@ -14,13 +15,13 @@ from pytensor.link.numba.cache import ( ...@@ -14,13 +15,13 @@ from pytensor.link.numba.cache import (
compile_numba_function_src, compile_numba_function_src,
) )
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
generate_fallback_impl, generate_fallback_impl,
register_funcify_and_cache_key, register_funcify_and_cache_key,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -29,7 +30,7 @@ from pytensor.tensor.subtensor import ( ...@@ -29,7 +30,7 @@ from pytensor.tensor.subtensor import (
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
) )
from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType from pytensor.tensor.type_other import MakeSlice, NoneTypeT
def slice_new(self, start, stop, step): def slice_new(self, start, stop, step):
...@@ -243,14 +244,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -243,14 +244,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
else: else:
_x, _y, *idxs = node.inputs _x, _y, *idxs = node.inputs
basic_idxs = [
idx
for idx in idxs
if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [ adv_idxs = [
{ {
"axis": i, "axis": i,
...@@ -262,248 +255,401 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -262,248 +255,401 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(idx.type, TensorType) if isinstance(idx.type, TensorType)
] ]
# Special implementation for consecutive integer vector indices must_ignore_duplicates = (
if ( isinstance(op, AdvancedIncSubtensor)
not basic_idxs and not op.set_instead_of_inc
and len(adv_idxs) >= 2 and op.ignore_duplicates
# Must be integer vectors # Only vector integer indices can have "duplicates", not scalars or boolean vectors
# Todo: we could allow shape=(1,) if this is the shape of x and not all(
and all( adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" for adv_idx in adv_idxs
(adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
for adv_idx in adv_idxs
) )
# Must be consecutive )
and not op.non_consecutive_adv_indexing(node)
# Special implementation for integer indices that respects duplicates
if (
not must_ignore_duplicates
and len(adv_idxs) >= 1
and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs)
# Implementation does not support newaxis
and not any(isinstance(idx.type, NoneTypeT) for idx in idxs)
): ):
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs) return vector_integer_advanced_indexing(op, node, **kwargs)
must_respect_duplicates = (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and not op.ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
and not all(
adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" for adv_idx in adv_idxs
)
)
# Other cases not natively supported by Numba (fallback to obj-mode) # Cases natively supported by Numba
if ( if (
# Numba indexing, like Numpy, ignores duplicates in update
not must_respect_duplicates
# Numba does not support indexes with more than one dimension # Numba does not support indexes with more than one dimension
any(idx["ndim"] > 1 for idx in adv_idxs) and not any(idx["ndim"] > 1 for idx in adv_idxs)
# Nor multiple vector indexes # Nor multiple vector indexes
or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1 and not sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
# The default PyTensor implementation does not handle duplicate indices correctly
or (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and not (
op.ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
or all(
adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool"
for adv_idx in adv_idxs
)
)
)
): ):
return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key( return numba_funcify_default_subtensor(op, node, **kwargs)
op, func="fallback_impl"
)
# What's left should all be supported natively by numba # Otherwise fallback to obj_mode
return numba_funcify_default_subtensor(op, node, **kwargs) return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key(
op, func="fallback_impl"
)
def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]): @register_funcify_and_cache_key(AdvancedIncSubtensor1)
# Check that x is not broadcasted to y based on broadcastable info def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if len(x_bcast) < len(to_bcast): return vector_integer_advanced_indexing(op, node=node, **kwargs)
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False
def numba_funcify_multiple_integer_vector_indexing( def vector_integer_advanced_indexing(
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs op: AdvancedSubtensor1 | AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
): ):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor) """Implement all forms of advanced indexing (and assignment) that combine basic and vector integer indices.
if isinstance(op, AdvancedSubtensor):
idxs = node.inputs[1:]
else:
idxs = node.inputs[2:]
first_axis = next( It does not support `newaxis` in basic indices
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
)
try:
after_last_axis = next(
i
for i, idx in enumerate(idxs[first_axis:], start=first_axis)
if not isinstance(idx.type, TensorType)
)
except StopIteration:
after_last_axis = len(idxs)
last_axis = after_last_axis - 1
vector_indices = idxs[first_axis:after_last_axis] It handles += like `np.add.at` would, accumulating add for duplicate indices.
assert all(v.type.broadcastable == (False,) for v in vector_indices)
y_is_broadcasted = False
if isinstance(op, AdvancedSubtensor): Examples
--------
Codegen for an AdvancedSubtensor, with non-consecutive matrix indices, and a slice(1, None) basic index
@numba_basic.numba_njit .. code-block:: python
def advanced_subtensor_multiple_vector(x, *idxs):
none_slices = idxs[:first_axis] # AdvancedSubtensor [id A] <Tensor3(int64, shape=(2, 2, 3))>
vec_idxs = idxs[first_axis:after_last_axis] # ├─ <Tensor3(int64, shape=(3, 4, 5))> [id B] <Tensor3(int64, shape=(3, 4, 5))>
# ├─ [[1 2] [2 1]] [id C] <Matrix(uint8, shape=(2, 2))>
x_shape = x.shape # ├─ SliceConstant{1, None, None} [id D] <slice>
idx_shape = vec_idxs[0].shape # └─ [[0 0] [0 0]] [id E] <Matrix(uint8, shape=(2, 2))>
shape_bef = x_shape[:first_axis]
shape_aft = x_shape[after_last_axis:]
out_shape = (*shape_bef, *idx_shape, *shape_aft) def advanced_integer_vector_indexing(x, idx0, idx1, idx2):
out_buffer = np.empty(out_shape, dtype=x.dtype) # Move advanced indexed dims to the front (if needed)
for i, scalar_idxs in enumerate(zip(*vec_idxs)): x_adv_dims_front = x.transpose((0, 2, 1))
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
# Perform basic indexing once (if needed)
basic_indexed_x = x_adv_dims_front[:, :, idx1]
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape)
(idx0, idx2) = (
np.broadcast_to(idx0, adv_idx_shape),
np.broadcast_to(idx2, adv_idx_shape),
)
# Create output buffer
adv_idx_size = idx0.size
basic_idx_shape = basic_indexed_x.shape[2:]
out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype)
# Index over tuples of raveled advanced indices and write to output buffer
for i, scalar_idxs in enumerate(zip(idx0.ravel(), idx2.ravel())):
out_buffer[i] = basic_indexed_x[scalar_idxs]
# Unravel out_buffer (if needed)
out_buffer = out_buffer.reshape((*adv_idx_shape, *basic_idx_shape))
# Move advanced output indexing group to its final position (if needed) and return
return out_buffer return out_buffer
ret_func = advanced_subtensor_multiple_vector
else: Codegen for similar AdvancedSetSubtensor
inplace = op.inplace
# Check if y must be broadcasted
# Includes the last integer vector index,
x, y = node.inputs[:2]
indexed_bcast_dims = (
*x.type.broadcastable[:first_axis],
*x.type.broadcastable[last_axis:],
)
y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)
if op.set_instead_of_inc: .. code-block::python
@numba_basic.numba_njit AdvancedSetSubtensor [id A] <Tensor3(int64, shape=(3, 4, 5))>
def advanced_set_subtensor_multiple_vector(x, y, *idxs): ├─ x [id B] <Tensor3(int64, shape=(3, 4, 5))>
vec_idxs = idxs[first_axis:after_last_axis] ├─ y [id C] <Matrix(int64, shape=(2, 4))>
x_shape = x.shape ├─ [1 2] [id D] <Vector(uint8, shape=(2,))>
├─ SliceConstant{None, None, None} [id E] <slice>
└─ [3 4] [id F] <Vector(uint8, shape=(2,))>
if inplace: def set_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2):
out = x # Expand dims of y explicitly (if needed)
else: y = y
out = x.copy()
if y_is_broadcasted: # Copy x (if not inplace)
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) x = x.copy()
for outer in np.ndindex(x_shape[:first_axis]): # Move advanced indexed dims to the front (if needed)
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # This will remain a view of x
out[(*outer, *scalar_idxs)] = y[(*outer, i)] x_adv_dims_front = x.transpose((0, 2, 1))
return out
ret_func = advanced_set_subtensor_multiple_vector # Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = x_adv_dims_front[:, :, idx1]
else: # Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape)
(idx0, idx2) = (np.broadcast_to(idx0, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape))
@numba_basic.numba_njit # Move advanced indexed dims to the front (if needed)
def advanced_inc_subtensor_multiple_vector(x, y, *idxs): y_adv_dims_front = y
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
if inplace: # Broadcast y to the shape of each assignment/update
out = x adv_idx_shape = idx0.shape
else: basic_idx_shape = basic_indexed_x.shape[2:]
out = x.copy() y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
if y_is_broadcasted: # Ravel the advanced dims (if needed)
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) # Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
for outer in np.ndindex(x_shape[:first_axis]): # Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip(*vec_idxs)): for i, scalar_idxs in enumerate(zip(idx0, idx2)):
out[(*outer, *scalar_idxs)] += y[(*outer, i)] basic_indexed_x[scalar_idxs] = y_bcast[i]
return out
ret_func = advanced_inc_subtensor_multiple_vector # Return the original x, with the entries updated
return x
cache_key = subtensor_op_cache_key(
op,
func="multiple_integer_vector_indexing",
y_is_broadcasted=y_is_broadcasted,
first_axis=first_axis,
last_axis=last_axis,
)
return ret_func, cache_key
Codegen for an AdvancedIncSubtensor, with two contiguous advanced groups not in the leading axis
@register_funcify_and_cache_key(AdvancedIncSubtensor1) .. code-block::python
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
x, vals, _idxs = node.inputs
broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
# TODO: Add runtime_broadcast check
if set_instead_of_inc:
if broadcast_with_index:
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] = core_val
return x
else: AdvancedIncSubtensor [id A] <Tensor3(int64, shape=(3, 4, 5))>
├─ x [id B] <Tensor3(int64, shape=(3, 4, 5))>
├─ y [id C] <Matrix(int64, shape=(2, 2))>
├─ SliceConstant{1, None, None} [id D] <slice>
├─ [1 2] [id E] <Vector(uint8, shape=(2,))>
└─ [3 4] [id F] <Vector(uint8, shape=(2,))>
@numba_basic.numba_njit(boundscheck=True) def inc_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2):
def advancedincsubtensor1_inplace(x, vals, idxs): # Expand dims of y explicitly (if needed)
if not len(idxs) == len(vals): y = y
raise ValueError("The number of indices and values must match.")
# no strict argument because incompatible with numba # Copy x (if not inplace)
for idx, val in zip(idxs, vals): x = x.copy()
x[idx] = val
return x # Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = x.transpose((1, 2, 0))
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = x_adv_dims_front[:, :, idx0]
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx1.shape, idx2.shape)
(idx1, idx2) = (np.broadcast_to(idx1, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape))
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = y.transpose((1, 0))
# Broadcast y to the shape of each assignment/update
adv_idx_shape = idx1.shape
basic_idx_shape = basic_indexed_x.shape[2:]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip(idx1, idx2)):
basic_indexed_x[scalar_idxs] += y_bcast[i]
# Return the original x, with the entries updated
return x
"""
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
x, *idxs = node.inputs
else: else:
if broadcast_with_index: x, y, *idxs = node.inputs
[out] = node.outputs
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
elif val.ndim == 0:
# Workaround for https://github.com/numba/numba/issues/9573
core_val = val.item()
else:
core_val = val
for idx in idxs:
x[idx] += core_val
return x
else: adv_indices_pos = tuple(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
)
assert adv_indices_pos # Otherwise it's just basic indexing
basic_indices_pos = tuple(
i for i, idx in enumerate(idxs) if not isinstance(idx.type, TensorType)
)
explicit_basic_indices_pos = (*basic_indices_pos, *range(len(idxs), x.type.ndim))
@numba_basic.numba_njit(boundscheck=True) # Create index signature and split them among basic and advanced
def advancedincsubtensor1_inplace(x, vals, idxs): idx_signature = ", ".join(f"idx{i}" for i in range(len(idxs)))
if not len(idxs) == len(vals): adv_indices = [f"idx{i}" for i in adv_indices_pos]
raise ValueError("The number of indices and values must match.") basic_indices = [f"idx{i}" for i in basic_indices_pos]
# no strict argument because unsupported by numba
# TODO: this doesn't come up in tests
for idx, val in zip(idxs, vals):
x[idx] += val
return x
cache_key = subtensor_op_cache_key( # Define transpose axis so that advanced indexing dims are on the front
op, adv_axis_front_order = (*adv_indices_pos, *explicit_basic_indices_pos)
func="numba_funcify_advancedincsubtensor1", adv_axis_front_transpose_needed = adv_axis_front_order != tuple(range(x.ndim))
broadcast_with_index=broadcast_with_index, adv_idx_ndim = max(idxs[i].ndim for i in adv_indices_pos)
# Helper needed for basic indexing after moving advanced indices to the front
basic_indices_with_none_slices = ", ".join(
(*((":",) * len(adv_indices)), *basic_indices)
) )
if inplace: # Position of the first advanced index dimension after indexing the array
return advancedincsubtensor1_inplace, cache_key if (np.diff(adv_indices_pos) > 1).any():
# If not consecutive, it's always at the front
out_adv_axis_pos = 0
else:
# Otherwise wherever the first advanced index is located
out_adv_axis_pos = adv_indices_pos[0]
to_tuple = create_tuple_string # alias to make code more readable below
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
# Define transpose axis on the output to restore original meaning
# After (potentially) having transposed advanced indexing dims to the front unlike numpy
_final_axis_order = list(range(adv_idx_ndim, out.type.ndim))
for i in range(adv_idx_ndim):
_final_axis_order.insert(out_adv_axis_pos + i, i)
final_axis_order = tuple(_final_axis_order)
del _final_axis_order
final_axis_transpose_needed = final_axis_order != tuple(range(out.type.ndim))
func_name = "advanced_integer_vector_indexing"
codegen = dedent(
f"""
def {func_name}(x, {idx_signature}):
# Move advanced indexed dims to the front (if needed)
x_adv_dims_front = {f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"}
# Perform basic indexing once (if needed)
basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"}
"""
)
if len(adv_indices) > 1:
codegen += indent(
dedent(
f"""
# Broadcast indices
adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])}
{to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])}
"""
),
" " * 4,
)
codegen += indent(
dedent(
f"""
# Create output buffer
adv_idx_size = {adv_indices[0]}.size
basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:]
out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype)
# Index over tuples of raveled advanced indices and write to output buffer
for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 else adv_indices)}):
out_buffer[i] = basic_indexed_x[scalar_idxs]
# Unravel out_buffer (if needed)
out_buffer = {f"out_buffer.reshape((*{adv_indices[0]}.shape, *basic_idx_shape))" if adv_idx_ndim != 1 else "out_buffer"}
# Move advanced output indexing group to its final position (if needed) and return
return {f"out_buffer.transpose({final_axis_order})" if final_axis_transpose_needed else "out_buffer"}
"""
),
" " * 4,
)
else: else:
# Make implicit dims of y explicit to simplify code
# Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis
indexed_ndim = x[tuple(idxs)].type.ndim
y_expand_dims = [":"] * y.type.ndim
y_implicit_dims = range(indexed_ndim - y.type.ndim)
for axis in y_implicit_dims:
y_expand_dims.insert(axis, "None")
# We transpose the advanced dimensions of x to the front for indexing
# We may have to do the same for y
# Note that if there are non-contiguous advanced indices,
# y must already be aligned with the indices jumping to the front
y_adv_axis_front_order = tuple(
range(
# Position of the first advanced axis after indexing
out_adv_axis_pos,
# Position of the last advanced axis after indexing
out_adv_axis_pos + adv_idx_ndim,
)
)
y_order = tuple(range(indexed_ndim))
y_adv_axis_front_order = (
*y_adv_axis_front_order,
# Basic indices, after explicit_expand_dims
*(o for o in y_order if o not in y_adv_axis_front_order),
)
y_adv_axis_front_transpose_needed = y_adv_axis_front_order != y_order
func_name = f"{'set' if op.set_instead_of_inc else 'inc'}_advanced_integer_vector_indexing"
codegen = dedent(
f"""
def {func_name}(x, y, {idx_signature}):
# Expand dims of y explicitly (if needed)
y = {f"y[{', '.join(y_expand_dims)},]" if y_implicit_dims else "y"}
# Copy x (if not inplace)
x = {"x" if op.inplace else "x.copy()"}
# Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = {f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"}
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"}
"""
)
if len(adv_indices) > 1:
codegen += indent(
dedent(
f"""
# Broadcast indices
adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])}
{to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])}
"""
),
" " * 4,
)
codegen += indent(
dedent(
f"""
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = {f"y.transpose({y_adv_axis_front_order})" if y_adv_axis_front_transpose_needed else "y"}
# Broadcast y to the shape of each assignment/update
adv_idx_shape = {adv_indices[0]}.shape
basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = {"y_bcast.ravel().reshape((-1, *basic_idx_shape))" if adv_idx_ndim != 1 else "y_bcast"}
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 else adv_indices)}):
basic_indexed_x[scalar_idxs] {"=" if op.set_instead_of_inc else "+="} y_bcast[i]
# Return the original x, with the entries updated
return x
"""
),
" " * 4,
)
@numba_basic.numba_njit cache_key = subtensor_op_cache_key(
def advancedincsubtensor1(x, vals, idxs): op,
x = x.copy() codegen=codegen,
return advancedincsubtensor1_inplace(x, vals, idxs) )
return advancedincsubtensor1, cache_key ret_func = numba_basic.numba_njit(
compile_numba_function_src(
codegen,
function_name=func_name,
global_env=globals(),
cache_key=cache_key,
)
)
return ret_func, cache_key
...@@ -83,7 +83,7 @@ from pytensor.tensor.subtensor import ( ...@@ -83,7 +83,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor, inc_subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node): ...@@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node):
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor]) @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node): def bool_idx_to_nonzero(fgraph, node):
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba """Convert boolean indexing into equivalent vector boolean index, supported by our dispatch
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()] x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
""" """
if isinstance(node.op, AdvancedSubtensor): if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs x, *idxs = node.inputs
else: else:
x, y, *idxs = node.inputs x, y, *idxs = node.inputs
if any( bool_pos = {
( i
(isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indexes or np.newaxis
return None
bool_idxs = [
(i, idx)
for i, idx in enumerate(idxs) for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype == "bool") if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
] }
if len(bool_idxs) != 1:
# Get out if there are no or multiple boolean idxs
return None
[(bool_idx_pos, bool_idx)] = bool_idxs if not bool_pos:
bool_idx_ndim = bool_idx.type.ndim
if bool_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
return None return None
x_shape = x.shape new_idxs = []
raveled_x = x.reshape( for i, idx in enumerate(idxs):
(*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :]) if i in bool_pos:
) new_idxs.extend(idx.nonzero())
else:
raveled_bool_idx = bool_idx.ravel() new_idxs.append(idx)
new_idxs = list(idxs)
new_idxs[bool_idx_pos] = raveled_bool_idx
if isinstance(node.op, AdvancedSubtensor): if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs) new_out = node.op(x, *new_idxs)
else: else:
# The dimensions of y that correspond to the boolean indices new_out = node.op(x, y, *new_idxs)
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = new_out.reshape(x_shape)
return [copy_stack_trace(node.outputs[0], new_out)] return [copy_stack_trace(node.outputs[0], new_out)]
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers
x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
It also handles multiple integer indices, but only if they don't broadcast
x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
"""
op = node.op
non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
if is_inc_subtensor:
x, y, *idxs = node.inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if non_consecutive_adv_indexing or (
y.type.broadcastable != x[tuple(idxs)].type.broadcastable
):
return None
else:
x, *idxs = node.inputs
if any(
(
(isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
):
# Get out if there are any other advanced indices or np.newaxis
return None
int_idxs_and_pos = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]
if not int_idxs_and_pos:
return None
int_idxs_pos, int_idxs = zip(
*int_idxs_and_pos, strict=False
) # strict=False because by definition it's true
first_int_idx_pos = int_idxs_pos[0]
first_int_idx = int_idxs[0]
first_int_idx_bcast = first_int_idx.type.broadcastable
if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return None
int_idxs_ndim = len(first_int_idx_bcast)
if (
int_idxs_ndim == 0
): # This should be a basic indexing operation, rewrite elsewhere
return None
int_idxs_need_raveling = int_idxs_ndim > 1
if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return None
# Reorder non-consecutive indices
if non_consecutive_adv_indexing:
assert not is_inc_subtensor # Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition = list(int_idxs_pos) + [
i for i in range(len(idxs)) if i not in int_idxs_pos
]
idxs = tuple(idxs[a] for a in transposition)
x = x.transpose(transposition)
first_int_idx_pos = 0
del int_idxs_pos # Make sure they are not wrongly used
# Ravel multidimensional indices
if int_idxs_need_raveling:
idxs = list(idxs)
for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
idxs[idx_pos] = int_idx.ravel()
# Index with reordered and/or raveled indices
new_subtensor = x[tuple(idxs)]
if is_inc_subtensor:
y_shape = tuple(y.shape)
y_raveled_shape = (
*y_shape[:first_int_idx_pos],
-1,
*y_shape[first_int_idx_pos + int_idxs_ndim :],
)
y_raveled = y.reshape(y_raveled_shape)
new_out = inc_subtensor(
new_subtensor,
y_raveled,
set_instead_of_inc=op.set_instead_of_inc,
ignore_duplicates=op.ignore_duplicates,
inplace=op.inplace,
)
else:
# Unravel advanced indexing dimensions
raveled_shape = tuple(new_subtensor.shape)
unraveled_shape = (
*raveled_shape[:first_int_idx_pos],
*first_int_idx.shape,
*raveled_shape[first_int_idx_pos + 1 :],
)
new_out = new_subtensor.reshape(unraveled_shape)
return [copy_stack_trace(node.outputs[0], new_out)]
optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
optdb["specialize"].register( optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__, bool_idx_to_nonzero.__name__,
ravel_multidimensional_int_idx, bool_idx_to_nonzero,
"numba", "numba",
"shape_unsafe", # It can mask invalid mask sizes
use_db_name_as_tag=False, # Not included if only "specialize" is requested use_db_name_as_tag=False, # Not included if only "specialize" is requested
) )
......
...@@ -109,117 +109,95 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -109,117 +109,95 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices, objmode_needed", "x, indices",
[ [
# Single vector indexing (supported natively by Numba) # Single vector indexing
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(0, [1, 2, 2, 3]), (0, [1, 2, 2, 3]),
False,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(np.array([True, False, False])), (np.array([True, False, False])),
False,
), ),
# Single multidimensional indexing (supported after specialization rewrites) # Single multidimensional indexing
( (
as_tensor(np.arange(3 * 3).reshape((3, 3))), as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(int)), (np.eye(3).astype(int)),
False,
), ),
( (
as_tensor(np.arange(3 * 3).reshape((3, 3))), as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(bool)), (np.eye(3).astype(bool)),
False,
), ),
( (
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))), as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(int)), (np.eye(3).astype(int)),
False,
), ),
( (
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))), as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(bool)), (np.eye(3).astype(bool)),
False,
), ),
( (
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))), as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(int)), (slice(2, None), np.eye(3).astype(int)),
False,
), ),
( (
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))), as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(bool)), (slice(2, None), np.eye(3).astype(bool)),
False,
), ),
# Multiple vector indexing (supported by our dispatcher) # Multiple vector indexing
( (
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], [2, 3]), ([1, 2], [2, 3]),
False,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]), (slice(None), [1, 2], [3, 4]),
False,
), ),
( (
as_tensor(np.arange(3 * 5 * 7).reshape((3, 5, 7))), as_tensor(np.arange(3 * 5 * 7).reshape((3, 5, 7))),
([1, 2], [3, 4], [5, 6]), ([1, 2], [3, 4], [5, 6]),
False,
), ),
# Non-consecutive vector indexing, supported by our dispatcher after rewriting # Non-consecutive vector indexing
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]),
False,
), ),
# Multiple multidimensional integer indexing (supported by our dispatcher) # Multiple multidimensional integer indexing
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [[0, 0], [0, 0]]), ([[1, 2], [2, 1]], [[0, 0], [0, 0]]),
False,
), ),
( (
as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))), as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))),
(slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]), (slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]),
False,
), ),
# Multiple multidimensional indexing with broadcasting, only supported in obj mode # Multiple multidimensional indexing with broadcasting
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]), ([[1, 2], [2, 1]], [0, 0]),
True,
), ),
# multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode # multiple multidimensional integer indexing mixed with basic indexing
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]), ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
True,
), ),
], ],
) )
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed @pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
def test_AdvancedSubtensor(x, indices, objmode_needed): def test_AdvancedSubtensor(x, indices):
"""Test NumPy's advanced indexing in more than one dimension.""" """Test NumPy's advanced indexing in more than one dimension."""
x_pt = x.type() x_pt = x.type()
out_pt = x_pt[indices] out_pt = x_pt[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor) assert isinstance(out_pt.owner.op, AdvancedSubtensor)
with ( compare_numba_and_py(
pytest.warns( [x_pt],
UserWarning, [out_pt],
match="Numba will use object mode to run AdvancedSubtensor's perform method", [x.data],
) # Specialize allows running boolean indexing without falling back to object mode
if objmode_needed # Thanks to bool_idx_to_nonzero rewrite
else contextlib.nullcontext() numba_mode=numba_mode.including("specialize"),
): )
compare_numba_and_py(
[x_pt],
[out_pt],
[x.data],
numba_mode=numba_mode.including("specialize"),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -323,7 +301,7 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -323,7 +301,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode", "x, y, indices, duplicate_indices, duplicate_indices_require_obj_mode",
[ [
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -331,7 +309,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -331,7 +309,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector index (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector index
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -343,7 +320,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -343,7 +320,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
), # Mixed basic and broadcasted vector idx ), # Mixed basic and broadcasted vector idx
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -351,7 +327,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -351,7 +327,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector idx (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector idx
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -359,7 +334,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -359,7 +334,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values (0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True, True,
False, False,
True,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -367,21 +341,11 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -367,21 +341,11 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values (0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values
True, True,
False, False,
True,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(1 * 4 * 5).reshape(1, 4, 5), -np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])), # Broadcasted boolean index (np.array([True, False, False])), # Broadcasted boolean index
False, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])), # Broadcasted boolean index
True, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
False, False,
False, False,
), ),
...@@ -391,7 +355,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -391,7 +355,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(np.eye(3).astype(bool)), # Boolean index (np.eye(3).astype(bool)), # Boolean index
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 3 * 5).reshape((3, 3, 5)), np.arange(3 * 3 * 5).reshape((3, 3, 5)),
...@@ -402,7 +365,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -402,7 +365,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
), # Boolean index, mixed with basic index ), # Boolean index, mixed with basic index
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -410,7 +372,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -410,7 +372,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([1, 2], [2, 3]), # 2 vector indices ([1, 2], [2, 3]), # 2 vector indices
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -418,7 +379,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -418,7 +379,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None), [1, 2], [2, 3]), # 2 vector indices (slice(None), [1, 2], [2, 3]), # 2 vector indices
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 6).reshape((3, 4, 6)), np.arange(3 * 4 * 6).reshape((3, 4, 6)),
...@@ -426,7 +386,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -426,7 +386,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([1, 2], [2, 3], [4, 5]), # 3 vector indices ([1, 2], [2, 3], [4, 5]), # 3 vector indices
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -434,15 +393,13 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -434,15 +393,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([1, 2], [2, 3]), # 2 vector indices ([1, 2], [2, 3]), # 2 vector indices
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 4)), rng.poisson(size=(2, 4)),
([1, 2], slice(None), [3, 4]), # Non-consecutive vector indices ([1, 2], slice(None), [3, 4]), # Non-consecutive vector indices
False, False,
True, False,
True,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -453,8 +410,7 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -453,8 +410,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
[3, 4], [3, 4],
), # Mixed double vector index and basic index ), # Mixed double vector index and basic index
False, False,
True, False,
True,
), ),
( (
np.arange(5), np.arange(5),
...@@ -462,7 +418,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -462,7 +418,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([[1, 2], [2, 3]]), # matrix index ([[1, 2], [2, 3]]), # matrix index
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 5).reshape((3, 5)), np.arange(3 * 5).reshape((3, 5)),
...@@ -470,23 +425,20 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -470,23 +425,20 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(1, 3), [[1, 2], [2, 3]]), # matrix index, mixed with basic index (slice(1, 3), [[1, 2], [2, 3]]), # matrix index, mixed with basic index
False, False,
False, False,
False,
), ),
( (
np.arange(3 * 5).reshape((3, 5)), np.arange(3 * 5).reshape((3, 5)),
rng.poisson(size=(1, 2, 2)), # Same as before, but Y broadcasts rng.poisson(size=(1, 2, 2)), # Same as before, but Y broadcasts
(slice(1, 3), [[1, 2], [2, 3]]), (slice(1, 3), [[1, 2], [2, 3]]),
False, False,
True, False,
True,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 5)), rng.poisson(size=(2, 5)),
([1, 1], [2, 2]), # Repeated indices ([1, 1], [2, 2]), # Repeated indices
True, True,
False, True,
False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
...@@ -494,7 +446,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -494,7 +446,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]), # 2 matrix indices (slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]), # 2 matrix indices
False, False,
False, False,
False,
), ),
], ],
) )
...@@ -505,8 +456,7 @@ def test_AdvancedIncSubtensor( ...@@ -505,8 +456,7 @@ def test_AdvancedIncSubtensor(
y, y,
indices, indices,
duplicate_indices, duplicate_indices,
set_requires_objmode, duplicate_indices_require_obj_mode,
inc_requires_objmode,
inplace, inplace,
): ):
# Need rewrite to support certain forms of advanced indexing without object mode # Need rewrite to support certain forms of advanced indexing without object mode
...@@ -518,17 +468,9 @@ def test_AdvancedIncSubtensor( ...@@ -518,17 +468,9 @@ def test_AdvancedIncSubtensor(
out_pt = set_subtensor(x_pt[indices], y_pt, inplace=inplace) out_pt = set_subtensor(x_pt[indices], y_pt, inplace=inplace)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
with ( fn, _ = compare_numba_and_py(
pytest.warns( [x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
UserWarning, )
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
)
if set_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
...@@ -536,23 +478,58 @@ def test_AdvancedIncSubtensor( ...@@ -536,23 +478,58 @@ def test_AdvancedIncSubtensor(
fn(x, y + 1) fn(x, y + 1)
assert not np.all(x == x_orig) assert not np.all(x == x_orig)
out_pt = inc_subtensor( out_pt = inc_subtensor(x_pt[indices], y_pt, inplace=inplace)
x_pt[indices], y_pt, ignore_duplicates=not duplicate_indices, inplace=inplace
)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
with (
pytest.warns( fn, _ = compare_numba_and_py(
UserWarning, [x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
match="Numba will use object mode to run AdvancedIncSubtensor's perform method", )
)
if inc_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
x_orig = x.copy() x_orig = x.copy()
fn(x, y) fn(x, y)
assert not np.all(x == x_orig) assert not np.all(x == x_orig)
if duplicate_indices:
# If inc_subtensor is called with `ignore_duplicates=True`, and it's not one of the cases supported by Numba
# We have to fall back to obj_mode
out_pt = inc_subtensor(
x_pt[indices], y_pt, inplace=inplace, ignore_duplicates=True
)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
)
if duplicate_indices_require_obj_mode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(
[x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace
)
if inplace:
# Test updates inplace
x_orig = x.copy()
fn(x, y)
assert not np.all(x == x_orig)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
# This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
# After which we can add these parametrizations to the relevant tests above
x = pt.matrix("x")
out = x[None, [0, 1, 2], [0, 1, 2]]
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
out = x[None, [0, 1, 2], [0, 1, 2]].inc(5)
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论