提交 9df54a54 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement shape inference for boolean advanced indexing

上级 f6407da2
...@@ -21,14 +21,18 @@ from pytensor.misc.safe_asarray import _asarray ...@@ -21,14 +21,18 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod from pytensor.tensor.math import ge, lt
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import switch
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.var import TensorVariable from pytensor.tensor.var import TensorVariable
...@@ -1063,7 +1067,7 @@ class FillDiagonalOffset(Op): ...@@ -1063,7 +1067,7 @@ class FillDiagonalOffset(Op):
# only valid for matrices # only valid for matrices
wr_a = fill_diagonal_offset(grad, 0, offset) wr_a = fill_diagonal_offset(grad, 0, offset)
offset_abs = at_abs(offset) offset_abs = pt_abs(offset)
pos_offset_flag = ge(offset, 0) pos_offset_flag = ge(offset, 0)
neg_offset_flag = lt(offset, 0) neg_offset_flag = lt(offset, 0)
min_wh = minimum(width, height) min_wh = minimum(width, height)
...@@ -1442,6 +1446,7 @@ _broadcast_assert = Assert( ...@@ -1442,6 +1446,7 @@ _broadcast_assert = Assert(
"axes that have a statically known length 1. Use `specify_broadcastable` to " "axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape." "inform PyTensor of a known shape."
) )
_runtime_broadcast_assert = Assert("Could not broadcast dimensions.")
def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]: def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
...@@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]: ...@@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
def broadcast_shape_iter( def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]], arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False, arrays_are_shapes: bool = False,
allow_runtime_broadcast: bool = False,
) -> Tuple[aes.ScalarVariable, ...]: ) -> Tuple[aes.ScalarVariable, ...]:
r"""Compute the shape resulting from broadcasting arrays. r"""Compute the shape resulting from broadcasting arrays.
...@@ -1480,22 +1486,24 @@ def broadcast_shape_iter( ...@@ -1480,22 +1486,24 @@ def broadcast_shape_iter(
arrays arrays
An iterable of tensors, or a tuple of shapes (as tuples), An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed. for which the broadcast shape is computed.
arrays_are_shapes arrays_are_shapes: bool, default False
Indicates whether or not the `arrays` contains shape tuples. Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1``--or simply the integer are (scalar) constants with the value ``1``--or simply the integer
``1``. ``1``. This is not revelant if `allow_runtime_broadcast` is True.
allow_runtime_broadcast: bool, default False
Whether to allow non-statically known broadcast on the shape computation.
""" """
one_at = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1) one = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)
if arrays_are_shapes: if arrays_are_shapes:
max_dims = max(len(a) for a in arrays) max_dims = max(len(a) for a in arrays)
array_shapes = [ array_shapes = [
(one_at,) * (max_dims - len(a)) (one,) * (max_dims - len(a))
+ tuple( + tuple(
one_at one
if sh == 1 or isinstance(sh, Constant) and sh.value == 1 if sh == 1 or isinstance(sh, Constant) and sh.value == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh) else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a for sh in a
...@@ -1508,10 +1516,8 @@ def broadcast_shape_iter( ...@@ -1508,10 +1516,8 @@ def broadcast_shape_iter(
_arrays = tuple(at.as_tensor_variable(a) for a in arrays) _arrays = tuple(at.as_tensor_variable(a) for a in arrays)
array_shapes = [ array_shapes = [
(one_at,) * (max_dims - a.ndim) (one,) * (max_dims - a.ndim)
+ tuple( + tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape))
one_at if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)
)
for a in _arrays for a in _arrays
] ]
...@@ -1520,11 +1526,11 @@ def broadcast_shape_iter( ...@@ -1520,11 +1526,11 @@ def broadcast_shape_iter(
for dim_shapes in zip(*array_shapes): for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not broadcastable # Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable) # (i.e. not symbolically known to be broadcastable)
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] non_bcast_shapes = [shape for shape in dim_shapes if shape != one]
if len(non_bcast_shapes) == 0: if len(non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension # Every shape was broadcastable in this dimension
result_dims.append(one_at) result_dims.append(one)
elif len(non_bcast_shapes) == 1: elif len(non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension # Only one shape might not be broadcastable in this dimension
result_dims.extend(non_bcast_shapes) result_dims.extend(non_bcast_shapes)
...@@ -1554,9 +1560,26 @@ def broadcast_shape_iter( ...@@ -1554,9 +1560,26 @@ def broadcast_shape_iter(
result_dims.append(first_length) result_dims.append(first_length)
continue continue
if not allow_runtime_broadcast:
# Add assert that all remaining shapes are equal # Add assert that all remaining shapes are equal
condition = pt_all([pt_eq(first_length, other) for other in other_lengths]) condition = pt_all(
[pt_eq(first_length, other) for other in other_lengths]
)
result_dims.append(_broadcast_assert(first_length, condition)) result_dims.append(_broadcast_assert(first_length, condition))
else:
lengths = as_tensor_variable((first_length, *other_lengths))
runtime_broadcastable = pt_eq(lengths, one)
result_dim = pt_abs(
pt_max(switch(runtime_broadcastable, -one, lengths))
)
condition = pt_all(
switch(
~runtime_broadcastable,
pt_eq(lengths, result_dim),
np.array(True),
)
)
result_dims.append(_runtime_broadcast_assert(result_dim, condition))
return tuple(result_dims) return tuple(result_dims)
......
...@@ -20,15 +20,11 @@ from pytensor.misc.safe_asarray import _asarray ...@@ -20,15 +20,11 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import ( from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
AdvancedIndexingError,
NotScalarConstantError,
ShapeError,
)
from pytensor.tensor.math import clip from pytensor.tensor.math import clip
from pytensor.tensor.shape import Reshape, specify_broadcastable from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
bscalar, bscalar,
...@@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
from pytensor.tensor.extra_ops import broadcast_shape from pytensor.tensor.extra_ops import broadcast_shape
res_shape += broadcast_shape( res_shape += broadcast_shape(
*grp_indices, arrays_are_shapes=indices_are_shapes *grp_indices,
arrays_are_shapes=indices_are_shapes,
# The AdvancedIndexing Op relies on the Numpy implementation which allows runtime broadcasting.
# As long as that is true, the shape inference has to respect that this is not an error.
allow_runtime_broadcast=True,
) )
res_shape += tuple(array_shape[dim] for dim in remaining_dims) res_shape += tuple(array_shape[dim] for dim in remaining_dims)
...@@ -2584,26 +2584,47 @@ class AdvancedSubtensor(Op): ...@@ -2584,26 +2584,47 @@ class AdvancedSubtensor(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs return self.make_node(eval_points[0], *inputs[1:]).outputs
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
indices = node.inputs[1:] def is_bool_index(idx):
index_shapes = list(ishapes[1:]) return (
for i, idx in enumerate(indices):
if (
isinstance(idx, (np.bool_, bool)) isinstance(idx, (np.bool_, bool))
or getattr(idx, "dtype", None) == "bool" or getattr(idx, "dtype", None) == "bool"
): )
raise ShapeError(
"Shape inference for boolean indices is not implemented" indices = node.inputs[1:]
index_shapes = []
for idx, ishape in zip(indices, ishapes[1:]):
# Mixed bool indexes are converted to nonzero entries
if is_bool_index(idx):
index_shapes.extend(
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
) )
# The `ishapes` entries for `SliceType`s will be None, and # The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices. # we need to give `indexed_result_shape` the actual slices.
if isinstance(getattr(idx, "type", None), SliceType): elif isinstance(getattr(idx, "type", None), SliceType):
index_shapes[i] = idx index_shapes.append(idx)
else:
index_shapes.append(ishape)
res_shape = indexed_result_shape( res_shape = list(
ishapes[0], index_shapes, indices_are_shapes=True indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
) )
adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
# Special logic when the only advanced index group is of bool type.
# We can replace the nonzeros by a sum of the whole bool variable.
if len(bool_indices) == 1 and len(adv_indices) == 1:
[bool_index] = bool_indices
# Find the output dim associated with the bool index group
# Because there are no more advanced index groups, there is exactly
# one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing.
start_dim = indices.index(bool_index)
res_shape[start_dim] = bool_index.sum()
assert node.outputs[0].ndim == len(res_shape) assert node.outputs[0].ndim == len(res_shape)
return [list(res_shape)] return [res_shape]
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
......
...@@ -1087,9 +1087,17 @@ def test_broadcast_shape_basic(): ...@@ -1087,9 +1087,17 @@ def test_broadcast_shape_basic():
assert any( assert any(
isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at) isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
) )
# This should fail because it would need dynamic broadcasting
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert np.array_equal([z.eval() for z in b_at], b.shape) assert np.array_equal([z.eval() for z in b_at], b.shape)
# But fine if we allow_runtime_broadcast
b_at = broadcast_shape(
shape_tuple(x_at, use_bcast=False),
shape_tuple(y_at, use_bcast=False),
arrays_are_shapes=True,
allow_runtime_broadcast=True,
)
assert np.array_equal([z.eval() for z in b_at], b.shape)
# Or if static bcast is known
b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True) b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_at], b.shape) assert np.array_equal([z.eval() for z in b_at], b.shape)
......
...@@ -63,6 +63,7 @@ from pytensor.tensor.type import ( ...@@ -63,6 +63,7 @@ from pytensor.tensor.type import (
tensor, tensor,
tensor3, tensor3,
tensor4, tensor4,
tensor5,
vector, vector,
) )
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
...@@ -2150,6 +2151,12 @@ class TestAdvancedSubtensor: ...@@ -2150,6 +2151,12 @@ class TestAdvancedSubtensor:
class TestInferShape(utt.InferShapeTester): class TestInferShape(utt.InferShapeTester):
@staticmethod
def random_bool_mask(shape, rng=None):
if rng is None:
rng = np.random.default_rng()
return rng.binomial(n=1, p=0.5, size=shape).astype(bool)
def test_IncSubtensor(self): def test_IncSubtensor(self):
admat = dmatrix() admat = dmatrix()
bdmat = dmatrix() bdmat = dmatrix()
...@@ -2439,25 +2446,85 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2439,25 +2446,85 @@ class TestInferShape(utt.InferShapeTester):
n = dmatrix() n = dmatrix()
n_val = np.arange(6).reshape((2, 3)) n_val = np.arange(6).reshape((2, 3))
# infer_shape is not implemented, but it should not crash # Shape inference requires runtime broadcasting between the nonzero() shapes
self._compile_and_check( self._compile_and_check(
[n], [n],
[n[n[:, 0] > 2, n[0, :] > 2]], [n[n[:, 0] > 2, n[0, :] > 2]],
[n_val], [n_val],
AdvancedSubtensor, AdvancedSubtensor,
check_topo=False,
) )
self._compile_and_check( self._compile_and_check(
[n], [n],
[n[n[:, 0] > 2]], [n[n[:, 0] > 2]],
[n_val], [n_val],
AdvancedSubtensor, AdvancedSubtensor,
check_topo=False, )
self._compile_and_check(
[n],
[n[:, np.array([True, False, True])]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[np.array([False, False]), 1:]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[np.array([True, True]), 0]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[self.random_bool_mask(n_val.shape)]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[None, self.random_bool_mask(n_val.shape), None]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
[n[slice(5, None), self.random_bool_mask(n_val.shape[1])]],
[n_val],
AdvancedSubtensor,
) )
abs_res = n[~isinf(n)] abs_res = n[~isinf(n)]
assert abs_res.type.shape == (None,) assert abs_res.type.shape == (None,)
def test_AdvancedSubtensor_bool_mixed(self):
n = tensor5("x", dtype="float64")
shape = (18, 3, 4, 5, 6)
n_val = np.arange(np.prod(shape)).reshape(shape)
self._compile_and_check(
[n],
# Consecutive advanced index
[n[1:, self.random_bool_mask((3, 4)), 0, 1:]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
# Non-consecutive advanced index
[n[1:, self.random_bool_mask((3, 4)), 1:, 0]],
[n_val],
AdvancedSubtensor,
)
self._compile_and_check(
[n],
# Non-consecutive advanced index
[n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]],
[n_val],
AdvancedSubtensor,
)
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
def test_basic_shape(): def test_basic_shape():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论