提交 85506229 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Support single multidimensional indexing in Numba via rewrites

上级 d5122713
......@@ -7,6 +7,7 @@ import numpy as np
import pytensor
import pytensor.scalar.basic as ps
from pytensor import compile
from pytensor.compile import optdb
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
WalkingGraphRewriter,
......@@ -1932,3 +1933,111 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
copy_stack_trace(node.outputs, new_out)
return new_out
@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
"""
x, *idxs = node.inputs
if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")
for idx in idxs
):
# Get out if there are any other advanced indexes
return None
bool_idxs = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
]
if len(bool_idxs) != 1:
# Get out if there are no or multiple boolean idxs
return None
[(bool_idx_pos, bool_idx)] = bool_idxs
bool_idx_ndim = bool_idx.type.ndim
if bool_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
return None
x_shape = x.shape
raveled_x = x.reshape(
(*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
)
raveled_bool_idx = bool_idx.ravel()
new_idxs = list(idxs)
new_idxs[bool_idx_pos] = raveled_bool_idx
return [raveled_x[tuple(new_idxs)]]
@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
"""
x, *idxs = node.inputs
if any(
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool")
for idx in idxs
):
# Get out if there are any other advanced indexes
return None
int_idxs = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int"))
]
if len(int_idxs) != 1:
# Get out if there are no or multiple integer idxs
return None
[(int_idx_pos, int_idx)] = int_idxs
if int_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
return None
raveled_int_idx = int_idx.ravel()
new_idxs = list(idxs)
new_idxs[int_idx_pos] = raveled_int_idx
raveled_subtensor = x[tuple(new_idxs)]
# Reshape into correct shape
# Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
# must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
raveled_shape = raveled_subtensor.shape
unraveled_shape = (
*raveled_shape[:int_idx_pos],
*int_idx.shape,
*raveled_shape[int_idx_pos + 1 :],
)
return [raveled_subtensor.reshape(unraveled_shape)]
optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
)
optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx,
"numba",
)
......@@ -19,7 +19,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor,
set_subtensor,
)
from tests.link.numba.test_basic import compare_numba_and_py
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
......@@ -74,6 +74,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize(
"x, indices, objmode_needed",
[
# Single vector indexing (supported natively by Numba)
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(0, [1, 2, 2, 3]),
......@@ -84,25 +85,63 @@ def test_AdvancedSubtensor1_out_of_bounds():
(np.array([True, False, False])),
False,
),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
# Single multidimensional indexing (supported after specialization rewrites)
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(bool)),
False,
),
# Multiple advanced indexing, only supported in obj mode
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]),
True,
),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]),
True,
),
],
)
@pytest.mark.filterwarnings("error")
def test_AdvancedSubtensor(x, indices, objmode_needed):
"""Test NumPy's advanced indexing in more than one dimension."""
out_pt = x[indices]
x_pt = x.type()
out_pt = x_pt[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
out_fg = FunctionGraph([x_pt], [out_pt])
with (
pytest.warns(
UserWarning,
......@@ -111,7 +150,11 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
if objmode_needed
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [])
compare_numba_and_py(
out_fg,
[x.data],
numba_mode=numba_mode.including("specialize"),
)
@pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论