提交 f9dfe702 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Ricardo Vieira

Refactor `get_canonical_form_slice` to fix subtensor typing

上级 906e1424
......@@ -3,6 +3,7 @@ import sys
from collections.abc import Callable, Iterable
from itertools import chain, groupby
from textwrap import dedent
from typing import cast, overload
import numpy as np
......@@ -19,13 +20,19 @@ from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import (
TensorLike,
_get_vector_length,
as_tensor_variable,
get_vector_length,
)
from pytensor.tensor.basic import (
ScalarFromTensor,
alloc,
get_underlying_scalar_constant_value,
nonzero,
scalar_from_tensor,
)
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
......@@ -51,8 +58,14 @@ from pytensor.tensor.type import (
wscalar,
zscalar,
)
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
from pytensor.tensor.variable import TensorVariable
from pytensor.tensor.type_other import (
NoneConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
from pytensor.tensor.variable import TensorConstant, TensorVariable
_logger = logging.getLogger("pytensor.tensor.subtensor")
......@@ -134,7 +147,7 @@ def indices_from_subtensor(
def as_index_constant(
a: slice | int | np.integer | Variable | None,
a: slice | int | np.integer | Variable | None | TensorLike,
) -> Variable | slice | None:
r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.
......@@ -150,15 +163,41 @@ def as_index_constant(
)
elif isinstance(a, int | np.integer):
return ps.ScalarConstant(ps.int64, a)
elif not isinstance(a, Variable):
return as_tensor_variable(a)
else:
elif isinstance(a, Variable):
return a
return as_tensor_variable(a)
@overload
def as_index_literal(idx: int | np.integer) -> int | np.integer: ...
@overload
def as_index_literal(idx: None) -> None: ...
@overload
def as_index_literal(idx: slice | SliceConstant) -> slice: ...
@overload
def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ...
@overload
def as_index_literal(idx: Variable): ...
def as_index_literal(
idx: Variable | slice | None,
) -> int | slice | None:
idx: None
| int
| np.integer
| slice
| SliceConstant
| ScalarConstant
| TensorConstant
| Variable,
) -> int | np.integer | slice | None:
"""Convert a symbolic index element to its Python equivalent.
This is like the inverse of `as_index_constant`
......@@ -167,22 +206,8 @@ def as_index_literal(
------
NotScalarConstantError
"""
if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT):
return np.newaxis
if isinstance(idx, Constant):
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
if isinstance(idx, Variable):
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return as_index_literal(idx.owner.inputs[0])
if isinstance(idx.type, SliceType):
idx = slice(*idx.owner.inputs)
if idx is None or isinstance(idx, int | np.integer):
return idx
if isinstance(idx, slice):
return slice(
......@@ -191,6 +216,33 @@ def as_index_literal(
as_index_literal(idx.step),
)
if not isinstance(idx, Variable):
raise TypeError(f"Not an index element: {idx}")
if isinstance(idx.type, NoneTypeT):
return None
if isinstance(idx, ScalarConstant):
return cast(int, idx.data)
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return cast(int | np.integer, as_index_literal(idx.owner.inputs[0]))
if isinstance(idx, TensorConstant):
return cast(int, idx.data.item())
if isinstance(idx, SliceConstant):
return cast(slice, idx.data)
if isinstance(idx.type, SliceType):
assert idx.owner is not None
return slice(*map(as_index_literal, idx.owner.inputs))
# Other kinds of variables are not supported
raise NotScalarConstantError()
......@@ -198,10 +250,30 @@ def get_idx_list(inputs, idx_list):
return indices_from_subtensor(inputs[1:], idx_list)
@overload
def get_canonical_form_slice(
theslice: slice,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[slice, int | ScalarConstant]: ...
@overload
def get_canonical_form_slice(
theslice: int | np.integer | ScalarVariable | TensorVariable,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[ScalarVariable, int]: ...
def get_canonical_form_slice(
theslice: slice | Variable, length: Variable
) -> tuple[Variable, int]:
"""Convert slices to canonical form.
theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
"""Convert indices or slices to canonical form.
Scalar integer indices or python Slices with Scalar/None attributes
used in basic Subtensor Ops are supported.
Symbolic slices (of SliceType) or vector indices
used in advanced Subtensor Ops are not supported.
Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy.
......@@ -210,18 +282,28 @@ def get_canonical_form_slice(
in which 0 <= start <= stop <= length and step > 0, and a flag which says
if the resulting set of numbers needs to be reversed or not.
Given a scalar index `idx` that may or may not be negative, convert it to
its non-negative equivalent `idx if idx >= 0 else length + idx`
(assuming the index is within bounds).
Returns
-------
slc
Canonical form slice or scalar variable.
direction
Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
"""
from pytensor.tensor import ge, lt, sign, switch
# Other non-slice types are the scalar indexing case
if not isinstance(theslice, slice):
try:
value = as_index_literal(theslice)
except NotScalarConstantError:
value = theslice
value = switch(lt(value, 0), (value + length), value)
if isinstance(theslice, int | np.integer | ScalarVariable) or (
isinstance(theslice, TensorVariable) and theslice.ndim == 0
):
cano = switch(lt(theslice, 0), (theslice + length), theslice)
return scalar_from_tensor(cano), 1
raise ValueError(f"Slice {theslice} is not a supported slice type.")
return value, 1
# At this point we have a slice object. Possibly with symbolic inputs.
def analyze(x):
try:
......@@ -243,6 +325,7 @@ def get_canonical_form_slice(
and is_step_constant
and is_length_constant
):
assert isinstance(length, int)
_start, _stop, _step = slice(start, stop, step).indices(length)
if _start <= _stop and _step >= 1:
return slice(_start, _stop, _step), 1
......@@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"):
return a[full_indices]
@_get_vector_length.register(Subtensor)
@_get_vector_length.register(Subtensor) # type: ignore
def _get_vector_length_Subtensor(op, var):
# If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
......
......@@ -25,7 +25,6 @@ pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py
pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py
pytensor/tensor/type.py
pytensor/tensor/type_other.py
pytensor/tensor/variable.py
......@@ -16,8 +16,8 @@ from pytensor.configdefaults import config
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar
from pytensor.tensor import get_vector_length, vectorize
from pytensor.scalar.basic import as_scalar, int16
from pytensor.tensor import as_tensor, get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf
......@@ -69,7 +69,13 @@ from pytensor.tensor.type import (
tensor5,
vector,
)
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
from pytensor.tensor.type_other import (
NoneConst,
SliceConstant,
as_symbolic_slice,
make_slice,
slicetype,
)
from tests import unittest_tools as utt
from tests.tensor.utils import inplace_func, integers_ranged, random
......@@ -106,11 +112,51 @@ def test_as_index_literal():
class TestGetCanonicalFormSlice:
@pytest.mark.parametrize(
    "idx",
    [
        NoneConst,
        None,
        as_symbolic_slice(slice(3, 7, 2)),
        as_symbolic_slice(slice(3, int16(), 2)),
        vector(),
    ],
)
def test_unsupported_inputs(self, idx):
    """Each parametrized `idx` must be rejected with a `ValueError`.

    The cases are a `None` constant, a plain `None`, two symbolic slices
    and a vector; the error message must mention "not a supported slice".
    """
    expected_failure = pytest.raises(ValueError, match="not a supported slice")
    with expected_failure:
        get_canonical_form_slice(idx, 5)
def test_scalar_constant(self):
    """A scalar constant index with a symbolic length is canonicalized.

    The returned index must be the output of a `ScalarFromTensor` op and
    the returned direction must be 1.
    """
    a = as_scalar(0)
    length = lscalar()
    res = get_canonical_form_slice(a, length)
    # The outermost op cannot be both `switch` and `ScalarFromTensor`; only
    # the `ScalarFromTensor` check is kept, matching the sibling tests.
    assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
    assert res[1] == 1
def test_tensor_constant(self):
    """A tensor constant index yields a `ScalarFromTensor` output and direction 1."""
    index = as_tensor(0)
    size = lscalar()
    canonical, direction = get_canonical_form_slice(index, size)
    assert isinstance(canonical.owner.op, ptb.ScalarFromTensor)
    assert direction == 1
def test_symbolic_scalar(self):
    """A symbolic scalar index with a symbolic length is canonicalized.

    The returned index must be the output of a `ScalarFromTensor` op and
    the returned direction must be 1.
    """
    a = int16()
    length = lscalar()
    res = get_canonical_form_slice(a, length)
    # `assert op, ptb.switch` only used `ptb.switch` as the failure message
    # and passed for any truthy op; perform a real check on the op type.
    assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
    assert res[1] == 1
def test_symbolic_tensor(self):
    """A symbolic 0-d tensor index yields a `ScalarFromTensor` output and direction 1."""
    index = lscalar()
    size = lscalar()
    canonical, direction = get_canonical_form_slice(index, size)
    assert isinstance(canonical.owner.op, ptb.ScalarFromTensor)
    assert direction == 1
def test_all_integer(self):
    """An all-integer slice with an integer length returns a Python slice and direction 1."""
    canonical, direction = get_canonical_form_slice(slice(1, 5, 2), 7)
    assert isinstance(canonical, slice)
    assert direction == 1
def test_all_symbolic(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论