提交 0fd160b6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement static_shape inference for AdvancedSubtensor

上级 f7cc0f07
......@@ -2,7 +2,7 @@ import logging
import sys
import warnings
from collections.abc import Callable, Iterable, Sequence
from itertools import chain, groupby
from itertools import chain, groupby, zip_longest
from typing import cast, overload
import numpy as np
......@@ -39,7 +39,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip
from pytensor.tensor.math import add, clip
from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
from pytensor.tensor.type import (
TensorType,
......@@ -63,6 +63,7 @@ from pytensor.tensor.type import (
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneSliceConst,
NoneTypeT,
SliceConstant,
SliceType,
......@@ -844,6 +845,24 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
return ps.as_scalar(a)
def slice_static_length(slc, dim_length):
    """Return the statically known length of ``slc`` applied to a dimension.

    Parameters
    ----------
    slc
        Object exposing ``start``, ``stop`` and ``step``; each entry is either
        ``None`` or something ``get_scalar_constant_value`` can be tried on.
    dim_length
        Static length of the dimension being sliced, or ``None`` if unknown.

    Returns
    -------
    int or None
        Number of elements the slice selects, or ``None`` when the dimension
        length is unknown or any slice entry is not a scalar constant.
    """
    if dim_length is None:
        # TODO: Some cases must be zero by definition, we could handle those
        return None

    def constant_or_none(entry):
        # ``None`` entries keep their default-slice meaning untouched.
        return None if entry is None else get_scalar_constant_value(entry)

    try:
        bounds = [
            constant_or_none(entry) for entry in (slc.start, slc.stop, slc.step)
        ]
    except NotScalarConstantError:
        # At least one entry is symbolic: the length cannot be known statically.
        return None

    # Let Python normalise the slice against the dimension and count the result.
    normalised = slice(*bounds).indices(dim_length)
    return len(range(*normalised))
class Subtensor(COp):
"""Basic NumPy indexing operator."""
......@@ -886,50 +905,15 @@ class Subtensor(COp):
)
padded = [
*get_idx_list((None, *inputs), self.idx_list),
*indices_from_subtensor(inputs, self.idx_list),
*[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
]
out_shape = []
def extract_const(value):
if value is None:
return value, True
try:
value = get_scalar_constant_value(value)
return value, True
except NotScalarConstantError:
return value, False
for the_slice, length in zip(padded, x.type.shape, strict=True):
if not isinstance(the_slice, slice):
continue
if length is None:
out_shape.append(None)
continue
start = the_slice.start
stop = the_slice.stop
step = the_slice.step
is_slice_const = True
start, is_const = extract_const(start)
is_slice_const = is_slice_const and is_const
stop, is_const = extract_const(stop)
is_slice_const = is_slice_const and is_const
step, is_const = extract_const(step)
is_slice_const = is_slice_const and is_const
if not is_slice_const:
out_shape.append(None)
continue
slice_length = len(range(*slice(start, stop, step).indices(length)))
out_shape.append(slice_length)
out_shape = [
slice_static_length(slc, length)
for slc, length in zip(padded, x.type.shape, strict=True)
if isinstance(slc, slice)
]
return Apply(
self,
......@@ -2826,36 +2810,112 @@ class AdvancedSubtensor(Op):
__props__ = ()
def make_node(self, x, *indices):
    """Build the Apply node, inferring the static shape of the output.

    Boolean tensor indices are expanded into integer indices via ``nonzero``
    so that the shape contributed by basic entries (slices and new axes) and
    by advanced entries (integer tensors) can be inferred separately and then
    combined.

    Parameters
    ----------
    x
        Value being indexed; converted with ``as_tensor_variable``.
    *indices
        Index entries; each one is converted with ``as_index_variable``.

    Returns
    -------
    Apply
        Node with inputs ``[x, *indices]`` and one output tensor that has the
        dtype of ``x`` and the inferred static shape.

    Raises
    ------
    NotImplementedError
        If a boolean index is 0-dimensional.
    IndexError
        If a boolean index has a static length that differs from the indexed
        axis, if more dimensions are indexed than ``x`` has, or if the
        advanced indices have static shapes that cannot be broadcast together.
    """
    x = as_tensor_variable(x)
    indices = tuple(map(as_index_variable, indices))

    explicit_indices = []
    # Positions, within `explicit_indices`, of the new-axis (None) entries
    new_axes = []
    for idx in indices:
        if isinstance(idx.type, TensorType) and idx.dtype == "bool":
            if idx.type.ndim == 0:
                raise NotImplementedError(
                    "Indexing with scalar booleans not supported"
                )

            # Check static shape aligned
            axis = len(explicit_indices) - len(new_axes)
            indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
            # Not strict: a boolean index reaching past the last axis of `x`
            # is reported by the "too many indices" check below.
            for j, (indexed_length, indexer_length) in enumerate(
                zip(indexed_shape, idx.type.shape, strict=False)
            ):
                if (
                    indexed_length is not None
                    and indexer_length is not None
                    and indexed_length != indexer_length
                ):
                    raise IndexError(
                        f"boolean index did not match indexed tensor along axis {axis + j}; "
                        f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
                    )
            # Convert boolean indices to integer with nonzero, to reason about static shape next
            if isinstance(idx, Constant):
                nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
            else:
                # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
                # and seeing that other integer indices cannot possibly match it
                nonzero_indices = idx.nonzero()
            explicit_indices.extend(nonzero_indices)
        else:
            if isinstance(idx.type, NoneTypeT):
                new_axes.append(len(explicit_indices))
            explicit_indices.append(idx)

    n_indexed_dims = len(explicit_indices) - len(new_axes)
    if n_indexed_dims > x.type.ndim:
        raise IndexError(
            f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {n_indexed_dims} were indexed"
        )

    # Perform basic and advanced indexing shape inference separately
    basic_group_shape = []
    advanced_indices = []
    adv_group_axis = None
    last_adv_group_axis = None

    # Line up the shape of `x` with `explicit_indices` by placing a length-1
    # entry at every new-axis position. `new_axes` is increasing and refers to
    # positions in the expanded layout, so inserting one at a time is correct
    # (a single vectorised insert would use positions of the original shape).
    expanded_x_shape = list(x.type.shape)
    for new_axis in new_axes:
        expanded_x_shape.insert(new_axis, 1)

    # Trailing dimensions without an explicit index behave like `slice(None)`
    for i, (idx, dim_length) in enumerate(
        zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
    ):
        if isinstance(idx.type, NoneTypeT):
            basic_group_shape.append(1)  # New-axis
        elif isinstance(idx.type, SliceType):
            if isinstance(idx, Constant):
                basic_group_shape.append(slice_static_length(idx.data, dim_length))
            elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
                basic_group_shape.append(
                    slice_static_length(slice(*idx.owner.inputs), dim_length)
                )
            else:
                # Symbolic root slice (owner is None), or slice operation we don't understand
                basic_group_shape.append(None)
        else:  # TensorType
            # Keep track of advanced group axis
            if adv_group_axis is None:
                # First time we see an advanced index
                adv_group_axis, last_adv_group_axis = i, i
            elif last_adv_group_axis == (i - 1):
                # Another advanced indexing aligned with the first group
                last_adv_group_axis = i
            else:
                # Non-consecutive advanced index, all advanced index views get moved to the front
                adv_group_axis = 0
            advanced_indices.append(idx)

    if advanced_indices:
        try:
            # Use variadic add to infer static shape of advanced integer indices
            advanced_group_static_shape = add(*advanced_indices).type.shape
        except ValueError as err:
            # It fails when static shapes are inconsistent
            static_shapes = [idx.type.shape for idx in advanced_indices]
            raise IndexError(
                f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
            ) from err
        # Combine advanced and basic views
        indexed_shape = [
            *basic_group_shape[:adv_group_axis],
            *advanced_group_static_shape,
            *basic_group_shape[adv_group_axis:],
        ]
    else:
        # This could have been a basic subtensor!
        indexed_shape = basic_group_shape

    return Apply(
        self,
        [x, *indices],
        [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
    )
def R_op(self, inputs, eval_points):
......
......@@ -114,6 +114,9 @@ def as_symbolic_slice(x, **kwargs):
return SliceConstant(slicetype, x)
NoneSliceConst = Constant(slicetype, slice(None), name="slice(None)")
class NoneTypeT(Generic):
"""
Inherit from Generic to have c code working.
......@@ -137,4 +140,4 @@ def as_symbolic_None(x, **kwargs):
return NoneConst
__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"]
......@@ -506,7 +506,9 @@ class _tensor_py_operators:
# Check if the number of dimensions isn't too large.
if self.ndim < index_dim_count:
raise IndexError("too many indices for array")
raise IndexError(
f"too many indices for tensor: tensor is {self.ndim}-dimensional, but {index_dim_count} were indexed"
)
# Convert an Ellipsis if provided into an appropriate number of
# slice(None).
......
import logging
import re
import sys
from io import StringIO
......@@ -1847,6 +1848,95 @@ class TestAdvancedSubtensor:
self.ix2 = lmatrix()
self.ixr = lrow()
def test_static_shape(self):
    """Check the static output shape reported for integer tensor indices."""
    x = tensor("x", shape=(None, None))
    y = tensor("y", shape=(4, 5, 6))
    idx1 = tensor("idx1", shape=(10,), dtype=int)
    idx2 = tensor("idx2", shape=(3, None), dtype=int)

    # Each entry pairs an indexed variable with the static shape it must report.
    cases = [
        (x[idx1], (10, None)),
        (x[:, idx1], (None, 10)),
        (x[idx2, :5], (3, None, None)),
        (specify_shape(x, (None, 7))[idx2, :5], (3, None, 5)),
        (specify_shape(x, (None, 3))[idx2, :5], (3, None, 3)),
        (x[idx1, idx2], (3, 10)),
        (x[idx2, idx1], (3, 10)),
        (x[None, idx1, idx2], (1, 3, 10)),
        (x[idx1, None, idx2], (3, 10, 1)),
        (x[idx1, idx2, None], (3, 10, 1)),
        (y[idx1, idx2, ::-1], (3, 10, 6)),
        (y[idx1, ::-1, idx2], (3, 10, 5)),
        (y[::-1, idx1, idx2], (4, 3, 10)),
        (y[::-1, idx1, None, idx2], (3, 10, 4, 1)),
    ]
    for indexed, expected_shape in cases:
        assert indexed.type.shape == expected_shape

    # Index tensors whose static shapes cannot broadcast are rejected up front.
    expected_error = re.escape(
        "shape mismatch: indexing tensors could not be broadcast together with shapes [(10,), (9,)]"
    )
    with pytest.raises(IndexError, match=expected_error):
        x[idx1, idx1[1:]]
def test_static_shape_boolean(self):
    """Check the static output shape reported when symbolic boolean masks are used."""
    y = tensor("y", shape=(4, 5, 6))
    idx1 = tensor("idx1", shape=(4,), dtype=int)
    idx2 = tensor("idx2", shape=(3, None), dtype=int)
    bool_idx1 = tensor("bool_idx1", shape=(4,), dtype=bool)
    bool_idx2 = tensor("bool_idx2", shape=(None, 5), dtype=bool)

    # Each entry pairs an indexed variable with the static shape it must report.
    cases = [
        (y[bool_idx1], (None, 5, 6)),
        (y[bool_idx1, :, None:-4:-1], (None, 5, 3)),
        (y[bool_idx1, idx2], (3, None, 6)),
        (y[bool_idx1, idx1, :], (4, 6)),
        (y[bool_idx1, :, idx1], (4, 5)),
        (y[bool_idx1, idx1, idx2], (3, 4)),
        (y[None, bool_idx1, None, idx2, None, idx1], (3, 4, 1, 1, 1)),
        (y[bool_idx2, :], (None, 6)),
        (y[bool_idx2, idx1], (4,)),
        (y[bool_idx2, idx2], (3, None)),
    ]
    for indexed, expected_shape in cases:
        assert indexed.type.shape == expected_shape

    # Two 2-d masks index four dimensions of a 3-d tensor.
    expected_error = re.escape(
        "too many indices for tensor: tensor is 3-dimensional, but 4 were indexed"
    )
    with pytest.raises(IndexError, match=expected_error):
        y[bool_idx2, bool_idx2]

    # Case that could conceivably be detected as index error at definition time
    bad_idx = ptb.concatenate([idx1, idx1])
    assert y[bool_idx1, bad_idx].type.shape == (8, 6)
def test_static_shape_constant_boolean(self):
    """Check the static output shape reported when constant boolean masks are used."""
    y = tensor("y", shape=(None, None, None))
    idx1 = tensor("idx1", shape=(3,), dtype=int)
    idx2 = tensor("idx2", shape=(4, None), dtype=int)
    mask_1d = np.array([True, False, True, True])
    mask_2d = np.array([[True, False, True, True], [True, False, False, True]])
    bool_idx1 = constant(mask_1d, name="bool_idx1")
    bool_idx2 = constant(mask_2d, name="bool_idx2")

    # Each entry pairs an indexed variable with the static shape it must report.
    cases = [
        (y[bool_idx1], (3, None, None)),
        (y[bool_idx1, :, idx1], (3, None)),
        (y[bool_idx1, :, idx2], (4, 3, None)),
        (y[bool_idx2], (5, None)),
        (y[bool_idx1, idx2], (4, 3, None)),
    ]
    for indexed, expected_shape in cases:
        assert indexed.type.shape == expected_shape

    # An integer index of static length 6 cannot broadcast with the length-3 mask result.
    bad_idx = ptb.concatenate([idx1, idx1])
    expected_error = re.escape(
        "shape mismatch: indexing tensors could not be broadcast together with shapes [(3,), (6,)]"
    )
    with pytest.raises(IndexError, match=expected_error):
        y[bool_idx1, bad_idx]
@pytest.mark.parametrize(
"inplace",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论