Commit 82a57574 authored by Ricardo Vieira, committed by Luciano Paz

Implement vectorize_node dispatch for some forms of AdvancedSubtensor

Parent 56637af8
...@@ -47,6 +47,7 @@ from pytensor.tensor.type import ( ...@@ -47,6 +47,7 @@ from pytensor.tensor.type import (
zscalar, zscalar,
) )
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
from pytensor.tensor.variable import TensorVariable
_logger = logging.getLogger("pytensor.tensor.subtensor") _logger = logging.getLogger("pytensor.tensor.subtensor")
...@@ -473,6 +474,13 @@ def group_indices(indices): ...@@ -473,6 +474,13 @@ def group_indices(indices):
return idx_groups return idx_groups
def _non_contiguous_adv_indexing(indices) -> bool:
    """Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
    groups = group_indices(indices)
    n_groups = len(groups)
    # More than three groups, or exactly three groups that start with a
    # non-advanced one, means there are at least two groups of advanced
    # indexing separated by basic indexing.
    return n_groups > 3 or (n_groups == 3 and not groups[0][0])
def indexed_result_shape(array_shape, indices, indices_are_shapes=False): def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
"""Compute the symbolic shape resulting from `a[indices]` for `a.shape == array_shape`. """Compute the symbolic shape resulting from `a[indices]` for `a.shape == array_shape`.
...@@ -497,8 +505,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -497,8 +505,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape)) remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
idx_groups = group_indices(indices) idx_groups = group_indices(indices)
if len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0]): if _non_contiguous_adv_indexing(indices):
# This means that there are at least two groups of advanced indexing separated by basic indexing
# In this case NumPy places the advanced index groups in the front of the array # In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing # https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups = sorted(idx_groups, key=lambda x: x[0]) idx_groups = sorted(idx_groups, key=lambda x: x[0])
...@@ -2682,10 +2689,68 @@ class AdvancedSubtensor(Op): ...@@ -2682,10 +2689,68 @@ class AdvancedSubtensor(Op):
rest rest
) )
@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
    """
    Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).

    This function checks if the advanced indexing is non-contiguous,
    in which case the advanced index dimensions are placed on the left of the
    output array, regardless of their original position.

    See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing

    Parameters
    ----------
    node : Apply
        The node of the AdvancedSubtensor operation.

    Returns
    -------
    bool
        True if the advanced indexing is non-contiguous, False otherwise.
    """
    # Drop the indexed tensor (first input); the remaining inputs are the indices.
    _, *idxs = node.inputs
    return _non_contiguous_adv_indexing(idxs)
# Module-level singleton instance used to build AdvancedSubtensor graphs.
# (The diff rendering duplicated this statement; the real code is a single assignment.)
advanced_subtensor = AdvancedSubtensor()
@_vectorize_node.register(AdvancedSubtensor)
def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
    """Vectorize an AdvancedSubtensor node over new leading batch dimensions."""
    x, *idxs = node.inputs
    batch_x, *batch_idxs = batch_inputs

    x_is_batched = x.type.ndim < batch_x.type.ndim

    # Detect whether any tensor index gained extra (batch) dimensions.
    idxs_are_batched = False
    for batch_idx, idx in zip(batch_idxs, idxs):
        if (
            isinstance(batch_idx, TensorVariable)
            and batch_idx.type.ndim > idx.type.ndim
        ):
            idxs_are_batched = True
            break

    if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
        # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
        # which would put the indexed results to the left of the batch dimensions!
        # TODO: Not all cases must be handled by Blockwise, but the logic is complex

        # Blockwise doesn't accept None or Slices types so we raise informative error here
        # TODO: Implement these internally, so Blockwise is always a safe fallback
        if not all(isinstance(idx, TensorVariable) for idx in idxs):
            raise NotImplementedError(
                "Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing "
                "and slices or newaxis is currently not supported."
            )
        return vectorize_node_fallback(op, node, batch_x, *batch_idxs)

    # Otherwise we just need to add None slices for every new batch dim
    x_batch_ndim = batch_x.type.ndim - x.type.ndim
    empty_slices = (slice(None),) * x_batch_ndim
    return op.make_node(batch_x, *empty_slices, *batch_idxs)
class AdvancedIncSubtensor(Op): class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing.""" """Increments a subtensor using advanced indexing."""
......
...@@ -2751,3 +2751,88 @@ def test_vectorize_subtensor_without_batch_indices(): ...@@ -2751,3 +2751,88 @@ def test_vectorize_subtensor_without_batch_indices():
vectorize_pt(x_test, start_test), vectorize_pt(x_test, start_test),
vectorize_np(x_test, start_test), vectorize_np(x_test, start_test),
) )
@pytest.mark.parametrize(
    "core_idx_fn, signature, x_shape, idx_shape, uses_blockwise",
    [
        # Core case
        ((lambda x, idx: x[:, idx, :]), "(7,5,3),(2)->(7,2,3)", (7, 5, 3), (2,), False),
        # Batched x, core idx
        (
            (lambda x, idx: x[:, idx, :]),
            "(7,5,3),(2)->(7,2,3)",
            (11, 7, 5, 3),
            (2,),
            False,
        ),
        (
            (lambda x, idx: x[idx, None]),
            "(5,7,3),(2)->(2,1,7,3)",
            (11, 5, 7, 3),
            (2,),
            False,
        ),
        # (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
        # due to the exact same None constant being used there and in the np.newaxis)
        pytest.param(
            (lambda x, idx: x[:, idx, None]),
            "(7,5,3),(2)->(7,2,1,3)",
            (11, 7, 5, 3),
            (2,),
            False,
            marks=pytest.mark.xfail(raises=NotImplementedError),
        ),
        (
            (lambda x, idx: x[:, idx, idx, :]),
            "(7,5,5,3),(2)->(7,2,3)",
            (11, 7, 5, 5, 3),
            (2,),
            False,
        ),
        # (not supported, because fallback Blockwise can't handle slices)
        pytest.param(
            (lambda x, idx: x[:, idx, :, idx]),
            "(7,5,3,5),(2)->(2,7,3)",
            (11, 7, 5, 3, 5),
            (2,),
            True,
            marks=pytest.mark.xfail(raises=NotImplementedError),
        ),
        # Core x, batched idx
        ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (7,), (11, 2), True),
        # Batched x, batched idx
        ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (11, 7), (11, 2), True),
        # (not supported, because fallback Blockwise can't handle slices)
        pytest.param(
            (lambda x, idx: x[:, idx, :]),
            "(t1,t2,t3),(idx)->(t1,tx,t3)",
            (11, 7, 5, 3),
            (11, 2),
            True,
            marks=pytest.mark.xfail(raises=NotImplementedError),
        ),
    ],
)
def test_vectorize_adv_subtensor(
    core_idx_fn, signature, x_shape, idx_shape, uses_blockwise
):
    """Check that vectorized advanced indexing matches ``np.vectorize`` and
    falls back to ``Blockwise`` exactly in the expected cases."""
    x = tensor(shape=x_shape, dtype="float64")
    idx = tensor(shape=idx_shape, dtype="int64")
    vectorize_pt = function(
        [x, idx], vectorize(core_idx_fn, signature=signature)(x, idx)
    )
    # Confirm the compiled graph uses (or avoids) the Blockwise fallback as expected.
    has_blockwise = any(
        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
    )
    assert has_blockwise == uses_blockwise
    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
    # Idx dimension should be length 5
    idx_test = np.random.randint(0, 5, size=idx.type.shape)
    # Compare against NumPy's reference implementation of the same core function.
    vectorize_np = np.vectorize(core_idx_fn, signature=signature)
    np.testing.assert_allclose(
        vectorize_pt(x_test, idx_test),
        vectorize_np(x_test, idx_test),
    )
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment