提交 bd281be6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Support consecutive integer vector indexing in Numba backend

上级 b8356ff9
...@@ -5,6 +5,7 @@ from pytensor.link.numba.dispatch import numba_funcify ...@@ -5,6 +5,7 @@ from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -13,6 +14,7 @@ from pytensor.tensor.subtensor import ( ...@@ -13,6 +14,7 @@ from pytensor.tensor.subtensor import (
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
) )
from pytensor.tensor.type_other import NoneTypeT, SliceType
@numba_funcify.register(Subtensor) @numba_funcify.register(Subtensor)
...@@ -104,18 +106,73 @@ def {function_name}({", ".join(input_names)}): ...@@ -104,18 +106,73 @@ def {function_name}({", ".join(input_names)}):
@numba_funcify.register(AdvancedSubtensor) @numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor) @numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs): def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:] if isinstance(op, AdvancedSubtensor):
adv_idxs_dims = [ x, y, idxs = node.inputs[0], None, node.inputs[1:]
idx.type.ndim else:
x, y, *idxs = node.inputs
basic_idxs = [
idx
for idx in idxs for idx in idxs
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
] ]
# Special case for consecutive integer vector indices
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False
# Special implementation for consecutive integer vector indices
if (
not basic_idxs
and len(adv_idxs) >= 2
# Must be integer vectors
# Todo: we could allow shape=(1,) if this is the shape of x
and all(
(adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
for adv_idx in adv_idxs
)
# Must be consecutive
and not op.non_contiguous_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
or not broadcasted_to(
y.type.broadcastable,
(
x.type.broadcastable[: adv_idxs[0]["axis"]]
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
),
)
)
):
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
# Other cases not natively supported by Numba (fallback to obj-mode)
if ( if (
# Numba does not support indexes with more than one dimension # Numba does not support indexes with more than one dimension
any(idx["ndim"] > 1 for idx in adv_idxs)
# Nor multiple vector indexes # Nor multiple vector indexes
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1) or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
# The default index implementation does not handle duplicate indices correctly # The default PyTensor implementation does not handle duplicate indices correctly
or ( or (
isinstance(op, AdvancedIncSubtensor) isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc and not op.set_instead_of_inc
...@@ -124,9 +181,91 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -124,9 +181,91 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
): ):
return generate_fallback_impl(op, node, **kwargs) return generate_fallback_impl(op, node, **kwargs)
# What's left should all be supported natively by numba
return numba_funcify_default_subtensor(op, node, **kwargs) return numba_funcify_default_subtensor(op, node, **kwargs)
def numba_funcify_multiple_integer_vector_indexing(
    op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
):
    # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
    #
    # The caller guarantees (via the checks in `numba_funcify_AdvancedSubtensor`)
    # that the advanced indices form a single contiguous run of 1-d integer
    # tensor indices, and that any remaining indices are full slices.
    if isinstance(op, AdvancedSubtensor):
        # Inputs: (x, *indices) — no update value.
        y, idxs = None, node.inputs[1:]
    else:
        # Inputs: (x, y, *indices) — y is the set/inc value.
        y, *idxs = node.inputs[1:]

    # Position of the first tensor (vector) index among the index inputs.
    first_axis = next(
        i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
    )
    # One past the last tensor index; if the vector run extends to the end of
    # the index list, `next` raises StopIteration and we use len(idxs).
    try:
        after_last_axis = next(
            i
            for i, idx in enumerate(idxs[first_axis:], start=first_axis)
            if not isinstance(idx.type, TensorType)
        )
    except StopIteration:
        after_last_axis = len(idxs)

    # `first_axis`/`after_last_axis`/`inplace` are closed over as compile-time
    # constants by the njit-compiled functions below.
    if isinstance(op, AdvancedSubtensor):

        @numba_njit
        def advanced_subtensor_multiple_vector(x, *idxs):
            # Leading indices before the vector run (full slices, passed through).
            none_slices = idxs[:first_axis]
            vec_idxs = idxs[first_axis:after_last_axis]

            x_shape = x.shape
            # All vector indices have the same length; the run of vector dims
            # collapses into a single output dim of that length.
            idx_shape = vec_idxs[0].shape
            shape_bef = x_shape[:first_axis]
            shape_aft = x_shape[after_last_axis:]
            out_shape = (*shape_bef, *idx_shape, *shape_aft)
            out_buffer = np.empty(out_shape, dtype=x.dtype)
            # Gather one slice of x per tuple of scalar indices.
            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
                out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
            return out_buffer

        return advanced_subtensor_multiple_vector

    elif op.set_instead_of_inc:
        inplace = op.inplace

        @numba_njit
        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
            vec_idxs = idxs[first_axis:after_last_axis]
            x_shape = x.shape

            if inplace:
                out = x
            else:
                out = x.copy()

            # Iterate explicitly over the leading dims (instead of slicing) so
            # each assignment uses only scalar indices.
            for outer in np.ndindex(x_shape[:first_axis]):
                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
                    out[(*outer, *scalar_idxs)] = y[(*outer, i)]
            return out

        return advanced_set_subtensor_multiple_vector

    else:
        inplace = op.inplace

        @numba_njit
        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
            vec_idxs = idxs[first_axis:after_last_axis]
            x_shape = x.shape

            if inplace:
                out = x
            else:
                out = x.copy()

            # Same iteration scheme as the set case, but accumulating. The
            # scalar-index loop handles duplicate indices correctly because
            # each increment is applied sequentially.
            for outer in np.ndindex(x_shape[:first_axis]):
                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
                    out[(*outer, *scalar_idxs)] += y[(*outer, i)]
            return out

        return advanced_inc_subtensor_multiple_vector
@numba_funcify.register(AdvancedIncSubtensor1) @numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace inplace = op.inplace
......
...@@ -2937,6 +2937,31 @@ class AdvancedIncSubtensor(Op): ...@@ -2937,6 +2937,31 @@ class AdvancedIncSubtensor(Op):
gy = _sum_grad_over_bcasted_dims(y, gy) gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()() for _ in idxs] return [gx, gy] + [DisconnectedType()() for _ in idxs]
@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
    """
    Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).

    This function checks if the advanced indexing is non-contiguous,
    in which case the advanced index dimensions are placed on the left of the
    output array, regardless of their original position.

    See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing

    Parameters
    ----------
    node : Apply
        The node of the AdvancedSubtensor operation.

    Returns
    -------
    bool
        True if the advanced indexing is non-contiguous, False otherwise.
    """
    # Skip the first two inputs (x and the set/inc value y); the rest are indices.
    _, _, *idxs = node.inputs
    return _non_contiguous_adv_indexing(idxs)
advanced_inc_subtensor = AdvancedIncSubtensor() advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True) advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
......
...@@ -228,9 +228,11 @@ def compare_numba_and_py( ...@@ -228,9 +228,11 @@ def compare_numba_and_py(
fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]], fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
inputs: Sequence["TensorLike"], inputs: Sequence["TensorLike"],
assert_fn: Callable | None = None, assert_fn: Callable | None = None,
*,
numba_mode=numba_mode, numba_mode=numba_mode,
py_mode=py_mode, py_mode=py_mode,
updates=None, updates=None,
inplace: bool = False,
eval_obj_mode: bool = True, eval_obj_mode: bool = True,
) -> tuple[Callable, Any]: ) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python graph output and Numba compiled output for testing equality
...@@ -276,7 +278,14 @@ def compare_numba_and_py( ...@@ -276,7 +278,14 @@ def compare_numba_and_py(
pytensor_py_fn = function( pytensor_py_fn = function(
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
) )
py_res = pytensor_py_fn(*inputs)
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
py_res = pytensor_py_fn(*test_inputs)
# Get some coverage (and catch errors in python mode before unreadable numba ones)
if eval_obj_mode:
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)
pytensor_numba_fn = function( pytensor_numba_fn = function(
fn_inputs, fn_inputs,
...@@ -285,11 +294,9 @@ def compare_numba_and_py( ...@@ -285,11 +294,9 @@ def compare_numba_and_py(
accept_inplace=True, accept_inplace=True,
updates=updates, updates=updates,
) )
numba_res = pytensor_numba_fn(*inputs)
# Get some coverage test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
if eval_obj_mode: numba_res = pytensor_numba_fn(*test_inputs)
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if len(fn_outputs) > 1: if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res, strict=True): for j, p in zip(numba_res, py_res, strict=True):
......
...@@ -85,7 +85,11 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -85,7 +85,11 @@ def test_AdvancedSubtensor1_out_of_bounds():
(np.array([True, False, False])), (np.array([True, False, False])),
False, False,
), ),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True), (
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], [2, 3]),
False,
),
# Single multidimensional indexing (supported after specialization rewrites) # Single multidimensional indexing (supported after specialization rewrites)
( (
as_tensor(np.arange(3 * 3).reshape((3, 3))), as_tensor(np.arange(3 * 3).reshape((3, 3))),
...@@ -117,17 +121,23 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -117,17 +121,23 @@ def test_AdvancedSubtensor1_out_of_bounds():
(slice(2, None), np.eye(3).astype(bool)), (slice(2, None), np.eye(3).astype(bool)),
False, False,
), ),
# Multiple advanced indexing, only supported in obj mode
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]), (slice(None), [1, 2], [3, 4]),
True, False,
),
(
as_tensor(np.arange(3 * 5 * 7).reshape((3, 5, 7))),
([1, 2], [3, 4], [5, 6]),
False,
), ),
# Non-contiguous vector indexing, only supported in obj mode
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]),
True, True,
), ),
# >1d vector indexing, only supported in obj mode
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]), ([[1, 2], [2, 1]], [0, 0]),
...@@ -135,7 +145,7 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -135,7 +145,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
), ),
], ],
) )
@pytest.mark.filterwarnings("error") @pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
def test_AdvancedSubtensor(x, indices, objmode_needed): def test_AdvancedSubtensor(x, indices, objmode_needed):
"""Test NumPy's advanced indexing in more than one dimension.""" """Test NumPy's advanced indexing in more than one dimension."""
x_pt = x.type() x_pt = x.type()
...@@ -268,94 +278,151 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -268,94 +278,151 @@ def test_AdvancedIncSubtensor1(x, y, indices):
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode", "x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
[ [
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(3 * 5).reshape(3, 5), -np.arange(3 * 5).reshape(3, 5),
(slice(None, None, 2), [1, 2, 3]), (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector index
False, False,
False, False,
False, False,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-99, np.array(-99), # Broadcasted value
(slice(None, None, 2), [1, 2, 3], -1), (
slice(None, None, 2),
[1, 2, 3],
-1,
), # Mixed basic and broadcasted vector idx
False, False,
False, False,
False, False,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-99, # Broadcasted value np.array(-99), # Broadcasted value
(slice(None, None, 2), [1, 2, 3]), (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector idx
False, False,
False, False,
False, False,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(4 * 5).reshape(4, 5), -np.arange(4 * 5).reshape(4, 5),
(0, [1, 2, 2, 3]), (0, [1, 2, 2, 3]), # Broadcasted vector index
True, True,
False, False,
True, True,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
[-99], # Broadcsasted value np.array([-99]), # Broadcasted value
(0, [1, 2, 2, 3]), (0, [1, 2, 2, 3]), # Broadcasted vector index
True, True,
False, False,
True, True,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
-np.arange(1 * 4 * 5).reshape(1, 4, 5), -np.arange(1 * 4 * 5).reshape(1, 4, 5),
(np.array([True, False, False])), (np.array([True, False, False])), # Broadcasted boolean index
False, False,
False, False,
False, False,
), ),
( (
as_tensor(np.arange(3 * 3).reshape((3, 3))), np.arange(3 * 3).reshape((3, 3)),
-np.arange(3), -np.arange(3),
(np.eye(3).astype(bool)), (np.eye(3).astype(bool)), # Boolean index
False, False,
True, True,
True, True,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
as_tensor(rng.poisson(size=(2, 5))), rng.poisson(size=(2, 5)),
([1, 2], [2, 3]), ([1, 2], [2, 3]), # 2 vector indices
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(3, 2)),
(slice(None), [1, 2], [2, 3]), # 2 vector indices
False,
False,
False,
),
(
np.arange(3 * 4 * 6).reshape((3, 4, 6)),
rng.poisson(size=(2,)),
([1, 2], [2, 3], [4, 5]), # 3 vector indices
False,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
np.array(-99), # Broadcasted value
([1, 2], [2, 3]), # 2 vector indices
False, False,
True, True,
True, True,
), ),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
as_tensor(rng.poisson(size=(2, 4))), rng.poisson(size=(2, 4)),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]), # Non-contiguous vector indices
False, False,
True, True,
True, True,
), ),
pytest.param( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
as_tensor(rng.poisson(size=(2, 5))), rng.poisson(size=(2, 2)),
([1, 1], [2, 2]), (
slice(1, None),
[1, 2],
[3, 4],
), # Mixed double vector index and basic index
False,
True,
True,
),
(
np.arange(5),
rng.poisson(size=(2, 2)),
([[1, 2], [2, 3]]), # matrix indices
False, False,
True, True,
True, True,
), ),
pytest.param(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
rng.poisson(size=(2, 5)),
([1, 1], [2, 2]), # Repeated indices
True,
False,
False,
),
], ],
) )
@pytest.mark.filterwarnings("error") @pytest.mark.parametrize("inplace", (False, True))
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
def test_AdvancedIncSubtensor( def test_AdvancedIncSubtensor(
x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode x,
y,
indices,
duplicate_indices,
set_requires_objmode,
inc_requires_objmode,
inplace,
): ):
out_pt = set_subtensor(x[indices], y) x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y")
out_pt = set_subtensor(x_pt[indices], y_pt, inplace=inplace)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
with ( with (
pytest.warns( pytest.warns(
...@@ -365,11 +432,18 @@ def test_AdvancedIncSubtensor( ...@@ -365,11 +432,18 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode if set_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
compare_numba_and_py(out_fg, []) fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
if inplace:
# Test updates inplace
x_orig = x.copy()
fn(x, y + 1)
assert not np.all(x == x_orig)
out_pt = inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices) out_pt = inc_subtensor(
x_pt[indices], y_pt, ignore_duplicates=not duplicate_indices, inplace=inplace
)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
with ( with (
pytest.warns( pytest.warns(
UserWarning, UserWarning,
...@@ -378,21 +452,9 @@ def test_AdvancedIncSubtensor( ...@@ -378,21 +452,9 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode if inc_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
compare_numba_and_py(out_fg, []) fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
if inplace:
x_pt = x.type() # Test updates inplace
out_pt = set_subtensor(x_pt[indices], y) x_orig = x.copy()
# Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just fn(x, y)
# hack it on here assert not np.all(x == x_orig)
out_pt.owner.op.inplace = True
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
)
if set_requires_objmode
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [x.data])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论