提交 85506229 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Support single multidimensional indexing in Numba via rewrites

上级 d5122713
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import pytensor import pytensor
import pytensor.scalar.basic as ps import pytensor.scalar.basic as ps
from pytensor import compile from pytensor import compile
from pytensor.compile import optdb
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
WalkingGraphRewriter, WalkingGraphRewriter,
...@@ -1932,3 +1933,111 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): ...@@ -1932,3 +1933,111 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
copy_stack_trace(node.outputs, new_out) copy_stack_trace(node.outputs, new_out)
return new_out return new_out
@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
    """Convert multidimensional boolean indexing into an equivalent vector boolean index.

    Vector boolean indexing is supported by Numba, so the dimensions of ``x``
    covered by the boolean index are collapsed into a single dimension and the
    boolean index is raveled to match:

    x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]

    Parameters
    ----------
    fgraph
        The graph being rewritten (unused here).
    node
        An ``AdvancedSubtensor`` node whose inputs are ``x`` followed by the indexes.

    Returns
    -------
    list or None
        A single-element list with the replacement output, or ``None`` when the
        node does not have exactly one boolean tensor index with ``ndim >= 2``
        or when it also has an integer tensor index.
    """
    x, *idxs = node.inputs

    if any(
        isinstance(idx.type, TensorType)
        # Both signed ("int*") and unsigned ("uint*") dtypes are advanced integer indexes
        and idx.type.dtype.startswith(("int", "uint"))
        for idx in idxs
    ):
        # Get out if there are any other advanced indexes
        return None

    bool_idxs = [
        (i, idx)
        for i, idx in enumerate(idxs)
        if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
    ]

    if len(bool_idxs) != 1:
        # Get out if there are no or multiple boolean idxs
        return None

    [(bool_idx_pos, bool_idx)] = bool_idxs
    bool_idx_ndim = bool_idx.type.ndim
    if bool_idx_ndim < 2:
        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
        return None

    # Collapse the `bool_idx_ndim` dimensions of x covered by the boolean index into one.
    # NOTE(review): `bool_idx_pos` is the position in `idxs` and is used directly as the
    # first covered dimension of x; this assumes no earlier index adds a dimension
    # (e.g. np.newaxis) -- confirm against callers.
    x_shape = x.shape
    raveled_x = x.reshape(
        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
    )

    raveled_bool_idx = bool_idx.ravel()
    new_idxs = list(idxs)
    new_idxs[bool_idx_pos] = raveled_bool_idx

    new_out = [raveled_x[tuple(new_idxs)]]
    copy_stack_trace(node.outputs, new_out)
    return new_out
@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
    """Convert multidimensional integer indexing into an equivalent vector integer index.

    Vector integer indexing is supported by Numba, so the integer index is raveled
    and the result is reshaped back to the shape the original indexing would give:

    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))

    NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor`
    except it also handles non-full slices

    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)),
    where ... are the remaining output shapes

    Parameters
    ----------
    fgraph
        The graph being rewritten (unused here).
    node
        An ``AdvancedSubtensor`` node whose inputs are ``x`` followed by the indexes.

    Returns
    -------
    list or None
        A single-element list with the replacement output, or ``None`` when the
        node does not have exactly one integer (signed or unsigned) tensor index
        with ``ndim >= 2`` or when it also has a boolean tensor index.
    """
    x, *idxs = node.inputs

    if any(
        isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool")
        for idx in idxs
    ):
        # Get out if there are any other advanced indexes
        return None

    int_idxs = [
        (i, idx)
        for i, idx in enumerate(idxs)
        # Unsigned dtypes ("uint*") are integer indexes too; missing them would let a
        # second advanced index go unnoticed and break broadcasting after raveling
        if (
            isinstance(idx.type, TensorType)
            and idx.dtype.startswith(("int", "uint"))
        )
    ]

    if len(int_idxs) != 1:
        # Get out if there are no or multiple integer idxs
        return None

    [(int_idx_pos, int_idx)] = int_idxs
    if int_idx.type.ndim < 2:
        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
        return None

    raveled_int_idx = int_idx.ravel()
    new_idxs = list(idxs)
    new_idxs[int_idx_pos] = raveled_int_idx
    raveled_subtensor = x[tuple(new_idxs)]

    # Reshape into correct shape
    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
    raveled_shape = raveled_subtensor.shape
    unraveled_shape = (
        *raveled_shape[:int_idx_pos],
        *int_idx.shape,
        *raveled_shape[int_idx_pos + 1 :],
    )

    new_out = [raveled_subtensor.reshape(unraveled_shape)]
    copy_stack_trace(node.outputs, new_out)
    return new_out
# Register both index-raveling rewrites in the "specialize" rewrite database,
# tagged "numba", under each rewrite function's own name.
optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
)
# Same registration for the integer-index counterpart.
optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx,
"numba",
)
...@@ -19,7 +19,7 @@ from pytensor.tensor.subtensor import ( ...@@ -19,7 +19,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor, inc_subtensor,
set_subtensor, set_subtensor,
) )
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
...@@ -74,6 +74,7 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -74,6 +74,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices, objmode_needed", "x, indices, objmode_needed",
[ [
# Single vector indexing (supported natively by Numba)
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(0, [1, 2, 2, 3]), (0, [1, 2, 2, 3]),
...@@ -84,25 +85,63 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -84,25 +85,63 @@ def test_AdvancedSubtensor1_out_of_bounds():
(np.array([True, False, False])), (np.array([True, False, False])),
False, False,
), ),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
# Single multidimensional indexing (supported after specialization rewrites)
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(int)),
False,
),
( (
as_tensor(np.arange(3 * 3).reshape((3, 3))), as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(bool)), (np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(bool)),
False,
),
# Multiple advanced indexing, only supported in obj mode
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]),
True, True,
), ),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
( (
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]), ([1, 2], slice(None), [3, 4]),
True, True,
), ),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]),
True,
),
], ],
) )
@pytest.mark.filterwarnings("error") @pytest.mark.filterwarnings("error")
def test_AdvancedSubtensor(x, indices, objmode_needed): def test_AdvancedSubtensor(x, indices, objmode_needed):
"""Test NumPy's advanced indexing in more than one dimension.""" """Test NumPy's advanced indexing in more than one dimension."""
out_pt = x[indices] x_pt = x.type()
out_pt = x_pt[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor) assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt]) out_fg = FunctionGraph([x_pt], [out_pt])
with ( with (
pytest.warns( pytest.warns(
UserWarning, UserWarning,
...@@ -111,7 +150,11 @@ def test_AdvancedSubtensor(x, indices, objmode_needed): ...@@ -111,7 +150,11 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
if objmode_needed if objmode_needed
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
compare_numba_and_py(out_fg, []) compare_numba_and_py(
out_fg,
[x.data],
numba_mode=numba_mode.including("specialize"),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论