Commit 0fd160b6 authored by Ricardo Vieira, committed by Ricardo Vieira

Implement static_shape inference for AdvancedSubtensor

Parent f7cc0f07
@@ -2,7 +2,7 @@ import logging
 import sys
 import warnings
 from collections.abc import Callable, Iterable, Sequence
-from itertools import chain, groupby
+from itertools import chain, groupby, zip_longest
 from typing import cast, overload
 
 import numpy as np
@@ -39,7 +39,7 @@ from pytensor.tensor.basic import (
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
-from pytensor.tensor.math import clip
+from pytensor.tensor.math import add, clip
 from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
@@ -63,6 +63,7 @@ from pytensor.tensor.type import (
 from pytensor.tensor.type_other import (
     MakeSlice,
     NoneConst,
+    NoneSliceConst,
     NoneTypeT,
     SliceConstant,
     SliceType,
@@ -844,6 +845,24 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
     return ps.as_scalar(a)
 
 
+def slice_static_length(slc, dim_length):
+    if dim_length is None:
+        # TODO: Some cases must be zero by definition, we could handle those
+        return None
+
+    entries = [None, None, None]
+    for i, entry in enumerate((slc.start, slc.stop, slc.step)):
+        if entry is None:
+            continue
+        try:
+            entries[i] = get_scalar_constant_value(entry)
+        except NotScalarConstantError:
+            return None
+    return len(range(*slice(*entries).indices(dim_length)))
+
+
 class Subtensor(COp):
     """Basic NumPy indexing operator."""
@@ -886,50 +905,15 @@ class Subtensor(COp):
         )
 
         padded = [
-            *get_idx_list((None, *inputs), self.idx_list),
+            *indices_from_subtensor(inputs, self.idx_list),
             *[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
         ]
 
-        out_shape = []
-
-        def extract_const(value):
-            if value is None:
-                return value, True
-            try:
-                value = get_scalar_constant_value(value)
-                return value, True
-            except NotScalarConstantError:
-                return value, False
-
-        for the_slice, length in zip(padded, x.type.shape, strict=True):
-            if not isinstance(the_slice, slice):
-                continue
-
-            if length is None:
-                out_shape.append(None)
-                continue
-
-            start = the_slice.start
-            stop = the_slice.stop
-            step = the_slice.step
-
-            is_slice_const = True
-
-            start, is_const = extract_const(start)
-            is_slice_const = is_slice_const and is_const
-
-            stop, is_const = extract_const(stop)
-            is_slice_const = is_slice_const and is_const
-
-            step, is_const = extract_const(step)
-            is_slice_const = is_slice_const and is_const
-
-            if not is_slice_const:
-                out_shape.append(None)
-                continue
-
-            slice_length = len(range(*slice(start, stop, step).indices(length)))
-            out_shape.append(slice_length)
+        out_shape = [
+            slice_static_length(slc, length)
+            for slc, length in zip(padded, x.type.shape, strict=True)
+            if isinstance(slc, slice)
+        ]
 
         return Apply(
             self,
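With the refactor, basic `Subtensor` static shape inference collapses to a single comprehension over `slice_static_length`, with the same observable behavior as the removed loop. A hedged usage sketch (the tensor name and shapes are illustrative):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(8, None))
    assert x[1:5].type.shape == (4, None)       # constant slice on a known dim -> concrete length
    assert x[::-1, :3].type.shape == (8, None)  # unknown dim stays None even for a constant slice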
@@ -2826,36 +2810,112 @@ class AdvancedSubtensor(Op):
     __props__ = ()
 
-    def make_node(self, x, *index):
+    def make_node(self, x, *indices):
         x = as_tensor_variable(x)
-        index = tuple(map(as_index_variable, index))
-
-        # We create a fake symbolic shape tuple and identify the broadcast
-        # dimensions from the shape result of this entire subtensor operation.
-        with config.change_flags(compute_test_value="off"):
-            fake_shape = tuple(
-                tensor(dtype="int64", shape=()) if s != 1 else 1 for s in x.type.shape
-            )
-
-        fake_index = tuple(
-            chain.from_iterable(
-                pytensor.tensor.basic.nonzero(idx)
-                if getattr(idx, "ndim", 0) > 0
-                and getattr(idx, "dtype", None) == "bool"
-                else (idx,)
-                for idx in index
-            )
-        )
-
-        out_shape = tuple(
-            i.value if isinstance(i, Constant) else None
-            for i in indexed_result_shape(fake_shape, fake_index)
-        )
+        indices = tuple(map(as_index_variable, indices))
+
+        explicit_indices = []
+        new_axes = []
+        for idx in indices:
+            if isinstance(idx.type, TensorType) and idx.dtype == "bool":
+                if idx.type.ndim == 0:
+                    raise NotImplementedError(
+                        "Indexing with scalar booleans not supported"
+                    )
+
+                # Check static shapes are aligned
+                axis = len(explicit_indices) - len(new_axes)
+                indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
+                for j, (indexed_length, indexer_length) in enumerate(
+                    zip(indexed_shape, idx.type.shape)
+                ):
+                    if (
+                        indexed_length is not None
+                        and indexer_length is not None
+                        and indexed_length != indexer_length
+                    ):
+                        raise IndexError(
+                            f"boolean index did not match indexed tensor along axis {axis + j}; "
+                            f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
+                        )
+
+                # Convert boolean indices to integer with nonzero, to reason about static shape next
+                if isinstance(idx, Constant):
+                    nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
+                else:
+                    # Note: Sometimes we could infer a shape error by reasoning about the
+                    # largest possible size of nonzero and seeing that other integer
+                    # indices cannot possibly match it
+                    nonzero_indices = idx.nonzero()
+                explicit_indices.extend(nonzero_indices)
+            else:
+                if isinstance(idx.type, NoneTypeT):
+                    new_axes.append(len(explicit_indices))
+                explicit_indices.append(idx)
+
+        if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
+            raise IndexError(
+                f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
+            )
+
+        # Perform basic and advanced indexing shape inference separately
+        basic_group_shape = []
+        advanced_indices = []
+        adv_group_axis = None
+        last_adv_group_axis = None
+        # Insert a dummy length-1 dim for each new axis, so that expanded_x_shape
+        # aligns with explicit_indices (positions are shifted back to x's own dims)
+        expanded_x_shape = tuple(
+            np.insert(
+                np.array(x.type.shape, dtype=object),
+                [pos - k for k, pos in enumerate(new_axes)],
+                1,
+            )
+        )
+        for i, (idx, dim_length) in enumerate(
+            zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
+        ):
+            if isinstance(idx.type, NoneTypeT):
+                basic_group_shape.append(1)  # New axis
+            elif isinstance(idx.type, SliceType):
+                if isinstance(idx, Constant):
+                    basic_group_shape.append(slice_static_length(idx.data, dim_length))
+                elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
+                    basic_group_shape.append(
+                        slice_static_length(slice(*idx.owner.inputs), dim_length)
+                    )
+                else:
+                    # Symbolic root slice (owner is None), or slice operation we don't understand
+                    basic_group_shape.append(None)
+            else:  # TensorType
+                # Keep track of the advanced group axis
+                if adv_group_axis is None:
+                    # First time we see an advanced index
+                    adv_group_axis, last_adv_group_axis = i, i
+                elif last_adv_group_axis == (i - 1):
+                    # Another advanced index aligned with the first group
+                    last_adv_group_axis = i
+                else:
+                    # Non-consecutive advanced index; all advanced index views get moved to the front
+                    adv_group_axis = 0
+                advanced_indices.append(idx)
+
+        if advanced_indices:
+            try:
+                # Use variadic add to infer the static shape of the advanced integer indices
+                advanced_group_static_shape = add(*advanced_indices).type.shape
+            except ValueError:
+                # It fails when static shapes are inconsistent
+                static_shapes = [idx.type.shape for idx in advanced_indices]
+                raise IndexError(
+                    f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
+                )
+            # Combine advanced and basic views
+            indexed_shape = [
+                *basic_group_shape[:adv_group_axis],
+                *advanced_group_static_shape,
+                *basic_group_shape[adv_group_axis:],
+            ]
+        else:
+            # This could have been a basic subtensor!
+            indexed_shape = basic_group_shape
 
         return Apply(
             self,
-            (x, *index),
-            [tensor(dtype=x.type.dtype, shape=out_shape)],
+            [x, *indices],
+            [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
         )
 
     def R_op(self, inputs, eval_points):
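The `add(*advanced_indices)` call is doing double duty: under NumPy's advanced-indexing semantics, all integer (and nonzero-converted boolean) indices broadcast together, and the shape of that broadcast is the shape of the advanced group, so the commit reuses `add`'s existing static-shape broadcasting instead of reimplementing it. The group-placement rule tracked by `adv_group_axis` also mirrors NumPy; a small NumPy sketch (array shapes are illustrative):

    import numpy as np

    y = np.empty((4, 5, 6))
    i = np.zeros((3, 1), dtype=int)
    j = np.zeros((3, 7), dtype=int)

    # Consecutive advanced indices: the broadcast group (3, 7) stays at its axis.
    assert y[:, i, j].shape == (4, 3, 7)
    # Non-consecutive advanced indices: the group is moved to the front.
    assert y[i, :, j].shape == (3, 7, 5)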
@@ -114,6 +114,9 @@ def as_symbolic_slice(x, **kwargs):
     return SliceConstant(slicetype, x)
 
 
+NoneSliceConst = Constant(slicetype, slice(None), name="slice(None)")
+
+
 class NoneTypeT(Generic):
     """
     Inherit from Generic to have c code working.
 
@@ -137,4 +140,4 @@ def as_symbolic_None(x, **kwargs):
     return NoneConst
 
 
-__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
+__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"]
@@ -506,7 +506,9 @@ class _tensor_py_operators:
         # Check if the number of dimensions isn't too large.
         if self.ndim < index_dim_count:
-            raise IndexError("too many indices for array")
+            raise IndexError(
+                f"too many indices for tensor: tensor is {self.ndim}-dimensional, but {index_dim_count} were indexed"
+            )
 
         # Convert an Ellipsis if provided into an appropriate number of
         # slice(None).
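The enriched message surfaces the failure mode at graph-construction time; for example (tensor name and shape are illustrative):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3))
    x[0, 0, 0]  # IndexError: too many indices for tensor: tensor is 2-dimensional, but 3 were indexed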
 import logging
+import re
 import sys
 from io import StringIO
 
@@ -1847,6 +1848,95 @@ class TestAdvancedSubtensor:
         self.ix2 = lmatrix()
         self.ixr = lrow()
 
+    def test_static_shape(self):
+        x = tensor("x", shape=(None, None))
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(10,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+
+        assert x[idx1].type.shape == (10, None)
+        assert x[:, idx1].type.shape == (None, 10)
+        assert x[idx2, :5].type.shape == (3, None, None)
+        assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
+        assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)
+        assert x[idx1, idx2].type.shape == (3, 10)
+        assert x[idx2, idx1].type.shape == (3, 10)
+        assert x[None, idx1, idx2].type.shape == (1, 3, 10)
+        assert x[idx1, None, idx2].type.shape == (3, 10, 1)
+        assert x[idx1, idx2, None].type.shape == (3, 10, 1)
+        assert y[idx1, idx2, ::-1].type.shape == (3, 10, 6)
+        assert y[idx1, ::-1, idx2].type.shape == (3, 10, 5)
+        assert y[::-1, idx1, idx2].type.shape == (4, 3, 10)
+        assert y[::-1, idx1, None, idx2].type.shape == (3, 10, 4, 1)
+
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(10,), (9,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            x[idx1, idx1[1:]]
+
+    def test_static_shape_boolean(self):
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(4,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+        bool_idx1 = tensor("bool_idx1", shape=(4,), dtype=bool)
+        bool_idx2 = tensor("bool_idx2", shape=(None, 5), dtype=bool)
+
+        assert y[bool_idx1].type.shape == (None, 5, 6)
+        assert y[bool_idx1, :, None:-4:-1].type.shape == (None, 5, 3)
+        assert y[bool_idx1, idx2].type.shape == (3, None, 6)
+        assert y[bool_idx1, idx1, :].type.shape == (4, 6)
+        assert y[bool_idx1, :, idx1].type.shape == (4, 5)
+        assert y[bool_idx1, idx1, idx2].type.shape == (3, 4)
+        assert y[None, bool_idx1, None, idx2, None, idx1].type.shape == (3, 4, 1, 1, 1)
+        assert y[bool_idx2, :].type.shape == (None, 6)
+        assert y[bool_idx2, idx1].type.shape == (4,)
+        assert y[bool_idx2, idx2].type.shape == (3, None)
+
+        msg = re.escape(
+            "too many indices for tensor: tensor is 3-dimensional, but 4 were indexed"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx2, bool_idx2]
+
+        # Case that could conceivably be detected as an index error at definition time
+        bad_idx = ptb.concatenate([idx1, idx1])
+        assert y[bool_idx1, bad_idx].type.shape == (8, 6)
+
+    def test_static_shape_constant_boolean(self):
+        y = tensor("y", shape=(None, None, None))
+        idx1 = tensor("idx1", shape=(3,), dtype=int)
+        idx2 = tensor("idx2", shape=(4, None), dtype=int)
+        bool_idx1 = constant(np.array([True, False, True, True]), name="bool_idx1")
+        bool_idx2 = constant(
+            np.array([[True, False, True, True], [True, False, False, True]]),
+            name="bool_idx2",
+        )
+
+        assert y[bool_idx1].type.shape == (3, None, None)
+        assert y[bool_idx1, :, idx1].type.shape == (3, None)
+        assert y[bool_idx1, :, idx2].type.shape == (4, 3, None)
+        assert y[bool_idx2].type.shape == (5, None)
+        assert y[bool_idx1, idx2].type.shape == (4, 3, None)
+
+        bad_idx = ptb.concatenate([idx1, idx1])
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(3,), (6,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx1, bad_idx]
+
     @pytest.mark.parametrize(
         "inplace",
         [
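The constant-boolean cases get exact static lengths because a `Constant` mask is converted eagerly: `idx.data.nonzero()` runs NumPy's `nonzero` at graph-construction time, so the number of `True` entries is already known. For instance:

    import numpy as np

    mask = np.array([[True, False, True, True],
                     [True, False, False, True]])
    rows, cols = mask.nonzero()
    assert rows.shape == cols.shape == (5,)  # five True entries -> static length 5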