Commit 9df54a54 authored by Ricardo Vieira, committed by Ricardo Vieira

Implement shape inference for boolean advanced indexing

Parent f6407da2
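For context, a minimal NumPy sketch of the indexing semantics this shape inference must match: a boolean mask consumes as many axes as it has dimensions, and the resulting axis has length `mask.sum()`, which is only known at runtime.

import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# A 1d mask consumes one axis; the output axis has length mask.sum()
mask1 = np.array([True, False, True])
assert x[:, mask1].shape == (2, 2, 4)
assert x[:, mask1].shape[1] == mask1.sum()

# A 2d mask consumes two axes at once
mask2 = x[:, :, 0] > 4  # shape (2, 3)
assert x[mask2].shape == (mask2.sum(), 4)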
......@@ -21,14 +21,18 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod
from pytensor.tensor.math import ge, lt
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import switch
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.var import TensorVariable
......@@ -1063,7 +1067,7 @@ class FillDiagonalOffset(Op):
# only valid for matrices
wr_a = fill_diagonal_offset(grad, 0, offset)
offset_abs = at_abs(offset)
offset_abs = pt_abs(offset)
pos_offset_flag = ge(offset, 0)
neg_offset_flag = lt(offset, 0)
min_wh = minimum(width, height)
......@@ -1442,6 +1446,7 @@ _broadcast_assert = Assert(
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape."
)
_runtime_broadcast_assert = Assert("Could not broadcast dimensions.")
def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
......@@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False,
allow_runtime_broadcast: bool = False,
) -> Tuple[aes.ScalarVariable, ...]:
r"""Compute the shape resulting from broadcasting arrays.
......@@ -1480,22 +1486,24 @@ def broadcast_shape_iter(
arrays
An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed.
arrays_are_shapes
arrays_are_shapes: bool, default False
Indicates whether or not `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1``--or simply the integer
``1``.
``1``. This is not relevant if `allow_runtime_broadcast` is True.
allow_runtime_broadcast: bool, default False
Whether to allow broadcasting of dimensions that are not statically known to be 1 in the shape computation.
"""
one_at = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)
one = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)
if arrays_are_shapes:
max_dims = max(len(a) for a in arrays)
array_shapes = [
(one_at,) * (max_dims - len(a))
(one,) * (max_dims - len(a))
+ tuple(
one_at
one
if sh == 1 or isinstance(sh, Constant) and sh.value == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a
......@@ -1508,10 +1516,8 @@ def broadcast_shape_iter(
_arrays = tuple(at.as_tensor_variable(a) for a in arrays)
array_shapes = [
(one_at,) * (max_dims - a.ndim)
+ tuple(
one_at if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)
)
(one,) * (max_dims - a.ndim)
+ tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape))
for a in _arrays
]
......@@ -1520,11 +1526,11 @@ def broadcast_shape_iter(
for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
non_bcast_shapes = [shape for shape in dim_shapes if shape != one]
if len(non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension
result_dims.append(one_at)
result_dims.append(one)
elif len(non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension
result_dims.extend(non_bcast_shapes)
......@@ -1554,9 +1560,26 @@ def broadcast_shape_iter(
result_dims.append(first_length)
continue
# Add assert that all remaining shapes are equal
condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
result_dims.append(_broadcast_assert(first_length, condition))
if not allow_runtime_broadcast:
# Add assert that all remaining shapes are equal
condition = pt_all(
[pt_eq(first_length, other) for other in other_lengths]
)
result_dims.append(_broadcast_assert(first_length, condition))
else:
lengths = as_tensor_variable((first_length, *other_lengths))
runtime_broadcastable = pt_eq(lengths, one)
result_dim = pt_abs(
pt_max(switch(runtime_broadcastable, -one, lengths))
)
condition = pt_all(
switch(
~runtime_broadcastable,
pt_eq(lengths, result_dim),
np.array(True),
)
)
result_dims.append(_runtime_broadcast_assert(result_dim, condition))
return tuple(result_dims)
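A minimal NumPy sketch (with hypothetical variable names) of the runtime branch above: lengths known to be 1 are replaced by -1 so they never win the max, abs recovers a length of 1 when every input is broadcastable, and the assert condition checks that all remaining lengths agree with the result.

import numpy as np

lengths = np.array([5, 1, 5])  # candidate lengths of one dimension

runtime_broadcastable = lengths == 1
# -1 placeholders lose the max against any real length; abs maps an
# all-broadcastable column (max == -1) back to a length of 1
result_dim = np.abs(np.max(np.where(runtime_broadcastable, -1, lengths)))
assert result_dim == 5

# Every non-broadcastable length must equal the result
assert np.all(np.where(~runtime_broadcastable, lengths == result_dim, True))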
......
......@@ -20,15 +20,11 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import (
AdvancedIndexingError,
NotScalarConstantError,
ShapeError,
)
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip
from pytensor.tensor.shape import Reshape, specify_broadcastable
from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
from pytensor.tensor.type import (
TensorType,
bscalar,
......@@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
from pytensor.tensor.extra_ops import broadcast_shape
res_shape += broadcast_shape(
*grp_indices, arrays_are_shapes=indices_are_shapes
*grp_indices,
arrays_are_shapes=indices_are_shapes,
# The AdvancedIndexing Op relies on the NumPy implementation, which allows runtime broadcasting.
# As long as that is true, the shape inference has to respect that this is not an error.
allow_runtime_broadcast=True,
)
res_shape += tuple(array_shape[dim] for dim in remaining_dims)
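For illustration, a small NumPy example of the runtime broadcasting the comment above refers to: the integer arrays that boolean indices turn into are broadcast against each other only when the index is evaluated.

import numpy as np

x = np.arange(12).reshape(3, 4)

rows = np.nonzero(x[:, 0] > 0)[0]  # [1, 2]; length only known at runtime
cols = np.nonzero(x[0, :] > 1)[0]  # [2, 3]; length only known at runtime
# NumPy broadcasts the two index arrays when the indexing is performed;
# this succeeds because the lengths happen to match (or one of them is 1)
assert x[rows, cols].shape == (2,)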
......@@ -2584,26 +2584,47 @@ class AdvancedSubtensor(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs
def infer_shape(self, fgraph, node, ishapes):
indices = node.inputs[1:]
index_shapes = list(ishapes[1:])
for i, idx in enumerate(indices):
if (
def is_bool_index(idx):
return (
isinstance(idx, (np.bool_, bool))
or getattr(idx, "dtype", None) == "bool"
):
raise ShapeError(
"Shape inference for boolean indices is not implemented"
)
indices = node.inputs[1:]
index_shapes = []
for idx, ishape in zip(indices, ishapes[1:]):
# Mixed bool indexes are converted to nonzero entries
if is_bool_index(idx):
index_shapes.extend(
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
)
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
if isinstance(getattr(idx, "type", None), SliceType):
index_shapes[i] = idx
elif isinstance(getattr(idx, "type", None), SliceType):
index_shapes.append(idx)
else:
index_shapes.append(ishape)
res_shape = indexed_result_shape(
ishapes[0], index_shapes, indices_are_shapes=True
res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
)
adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
# Special logic when the only advanced index group is of bool type.
# We can replace the nonzeros by a sum of the whole bool variable.
if len(bool_indices) == 1 and len(adv_indices) == 1:
[bool_index] = bool_indices
# Find the output dim associated with the bool index group
# Because there is no other advanced index group, there is exactly
# one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing.
start_dim = indices.index(bool_index)
res_shape[start_dim] = bool_index.sum()
assert node.outputs[0].ndim == len(res_shape)
return [list(res_shape)]
return [res_shape]
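A NumPy sketch of the shortcut implemented above: when the boolean mask is the only advanced index, every `nonzero` array has length `mask.sum()`, so the shape graph can use the sum directly instead of materializing `nonzero`.

import numpy as np

x = np.arange(20).reshape(4, 5)
mask = x % 3 == 0  # boolean mask over both axes

rows, cols = np.nonzero(mask)
# The output length equals mask.sum(), the same value nonzero would yield
assert x[mask].shape == rows.shape == (mask.sum(),)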
def perform(self, node, inputs, out_):
(out,) = out_
......
......@@ -1087,9 +1087,17 @@ def test_broadcast_shape_basic():
assert any(
isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
)
# This should fail because it would need dynamic broadcasting
with pytest.raises(AssertionError):
assert np.array_equal([z.eval() for z in b_at], b.shape)
# But fine if we allow_runtime_broadcast
b_at = broadcast_shape(
shape_tuple(x_at, use_bcast=False),
shape_tuple(y_at, use_bcast=False),
arrays_are_shapes=True,
allow_runtime_broadcast=True,
)
assert np.array_equal([z.eval() for z in b_at], b.shape)
# Or if static bcast is known
b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_at], b.shape)
......
......@@ -63,6 +63,7 @@ from pytensor.tensor.type import (
tensor,
tensor3,
tensor4,
tensor5,
vector,
)
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
......@@ -2150,6 +2151,12 @@ class TestAdvancedSubtensor:
class TestInferShape(utt.InferShapeTester):
@staticmethod
def random_bool_mask(shape, rng=None):
if rng is None:
rng = np.random.default_rng()
return rng.binomial(n=1, p=0.5, size=shape).astype(bool)
def test_IncSubtensor(self):
admat = dmatrix()
bdmat = dmatrix()
......@@ -2439,25 +2446,85 @@ class TestInferShape(utt.InferShapeTester):
n = dmatrix()
n_val = np.arange(6).reshape((2, 3))
# infer_shape is not implemented, but it should not crash
# Shape inference requires runtime broadcasting between the nonzero() shapes
self._compile_and_check(
[n],
[n[n[:, 0] > 2, n[0, :] > 2]],
[n_val],
AdvancedSubtensor,
check_topo=False,
)
self._compile_and_check(
[n],
[n[n[:, 0] > 2]],
[n_val],
AdvancedSubtensor,
check_topo=False,
)
self._compile_and_check(
[n],
[n[:, np.array([True, False, True])]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[np.array([False, False]), 1:]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[np.array([True, True]), 0]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[self.random_bool_mask(n_val.shape)]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[None, self.random_bool_mask(n_val.shape), None]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[slice(5, None), self.random_bool_mask(n_val.shape[1])]],
[n_val],
AdvancedSubtensor,
)
abs_res = n[~isinf(n)]
assert abs_res.type.shape == (None,)
def test_AdvancedSubtensor_bool_mixed(self):
n = tensor5("x", dtype="float64")
shape = (18, 3, 4, 5, 6)
n_val = np.arange(np.prod(shape)).reshape(shape)
self._compile_and_check(
[n],
# Consecutive advanced index
[n[1:, self.random_bool_mask((3, 4)), 0, 1:]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
# Non-consecutive advanced index
[n[1:, self.random_bool_mask((3, 4)), 1:, 0]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
# Non-consecutive advanced index
[n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]],
[n_val],
AdvancedSubtensor,
)
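For reference, the NumPy behavior the mixed-index tests above exercise: the boolean mask contributes its `nonzero` arrays to the advanced-index group, which then broadcasts with any scalar or integer-array indices.

import numpy as np

shape = (18, 3, 4, 5, 6)
x = np.arange(np.prod(shape)).reshape(shape)
mask = np.random.default_rng(0).binomial(1, 0.5, size=(3, 4)).astype(bool)

# The mask consumes axes 1 and 2; its nonzero arrays and the scalar 0
# form one consecutive advanced group of length mask.sum()
assert x[1:, mask, 0, 1:].shape == (17, mask.sum(), 5)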
@config.change_flags(compute_test_value="raise")
def test_basic_shape():
......