提交 f9dfe702 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Ricardo Vieira

Refactor `get_canonical_form_slice` to fix subtensor typing

上级 906e1424
...@@ -3,6 +3,7 @@ import sys ...@@ -3,6 +3,7 @@ import sys
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import chain, groupby from itertools import chain, groupby
from textwrap import dedent from textwrap import dedent
from typing import cast, overload
import numpy as np import numpy as np
...@@ -19,13 +20,19 @@ from pytensor.link.c.op import COp ...@@ -19,13 +20,19 @@ from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import (
TensorLike,
_get_vector_length,
as_tensor_variable,
get_vector_length,
)
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
alloc, alloc,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
nonzero, nonzero,
scalar_from_tensor,
) )
from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
...@@ -51,8 +58,14 @@ from pytensor.tensor.type import ( ...@@ -51,8 +58,14 @@ from pytensor.tensor.type import (
wscalar, wscalar,
zscalar, zscalar,
) )
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice from pytensor.tensor.type_other import (
from pytensor.tensor.variable import TensorVariable NoneConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
from pytensor.tensor.variable import TensorConstant, TensorVariable
_logger = logging.getLogger("pytensor.tensor.subtensor") _logger = logging.getLogger("pytensor.tensor.subtensor")
...@@ -134,7 +147,7 @@ def indices_from_subtensor( ...@@ -134,7 +147,7 @@ def indices_from_subtensor(
def as_index_constant( def as_index_constant(
a: slice | int | np.integer | Variable | None, a: slice | int | np.integer | Variable | None | TensorLike,
) -> Variable | slice | None: ) -> Variable | slice | None:
r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments. r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.
...@@ -150,15 +163,41 @@ def as_index_constant( ...@@ -150,15 +163,41 @@ def as_index_constant(
) )
elif isinstance(a, int | np.integer): elif isinstance(a, int | np.integer):
return ps.ScalarConstant(ps.int64, a) return ps.ScalarConstant(ps.int64, a)
elif not isinstance(a, Variable): elif isinstance(a, Variable):
return as_tensor_variable(a)
else:
return a return a
return as_tensor_variable(a)
@overload
def as_index_literal(idx: int | np.integer) -> int | np.integer: ...
@overload
def as_index_literal(idx: None) -> None: ...
@overload
def as_index_literal(idx: slice | SliceConstant) -> slice: ...
@overload
def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ...
@overload
def as_index_literal(idx: Variable): ...
def as_index_literal( def as_index_literal(
idx: Variable | slice | None, idx: None
) -> int | slice | None: | int
| np.integer
| slice
| SliceConstant
| ScalarConstant
| TensorConstant
| Variable,
) -> int | np.integer | slice | None:
"""Convert a symbolic index element to its Python equivalent. """Convert a symbolic index element to its Python equivalent.
This is like the inverse of `as_index_constant` This is like the inverse of `as_index_constant`
...@@ -167,22 +206,8 @@ def as_index_literal( ...@@ -167,22 +206,8 @@ def as_index_literal(
------ ------
NotScalarConstantError NotScalarConstantError
""" """
if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT): if idx is None or isinstance(idx, int | np.integer):
return np.newaxis return idx
if isinstance(idx, Constant):
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
if isinstance(idx, Variable):
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return as_index_literal(idx.owner.inputs[0])
if isinstance(idx.type, SliceType):
idx = slice(*idx.owner.inputs)
if isinstance(idx, slice): if isinstance(idx, slice):
return slice( return slice(
...@@ -191,6 +216,33 @@ def as_index_literal( ...@@ -191,6 +216,33 @@ def as_index_literal(
as_index_literal(idx.step), as_index_literal(idx.step),
) )
if not isinstance(idx, Variable):
raise TypeError(f"Not an index element: {idx}")
if isinstance(idx.type, NoneTypeT):
return None
if isinstance(idx, ScalarConstant):
return cast(int, idx.data)
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return cast(int | np.integer, as_index_literal(idx.owner.inputs[0]))
if isinstance(idx, TensorConstant):
return cast(int, idx.data.item())
if isinstance(idx, SliceConstant):
return cast(slice, idx.data)
if isinstance(idx.type, SliceType):
assert idx.owner is not None
return slice(*map(as_index_literal, idx.owner.inputs))
# Other kinds of variables are not supported
raise NotScalarConstantError() raise NotScalarConstantError()
...@@ -198,10 +250,30 @@ def get_idx_list(inputs, idx_list): ...@@ -198,10 +250,30 @@ def get_idx_list(inputs, idx_list):
return indices_from_subtensor(inputs[1:], idx_list) return indices_from_subtensor(inputs[1:], idx_list)
@overload
def get_canonical_form_slice(
theslice: slice,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[slice, int | ScalarConstant]: ...
@overload
def get_canonical_form_slice(
theslice: int | np.integer | ScalarVariable | TensorVariable,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[ScalarVariable, int]: ...
def get_canonical_form_slice( def get_canonical_form_slice(
theslice: slice | Variable, length: Variable theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[Variable, int]: length: int | np.integer | ScalarVariable | TensorVariable,
"""Convert slices to canonical form. ) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
"""Convert indices or slices to canonical form.
Scalar integer indices or python Slices with Scalar/None attributes
used in basic Subtensor Ops are supported.
Symbolic slices (of SliceType) or vector indices
used in advanced Subtensor Ops are not supported.
Given a slice [start:stop:step] transform it into a canonical form Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy. that respects the conventions imposed by python and numpy.
...@@ -210,18 +282,28 @@ def get_canonical_form_slice( ...@@ -210,18 +282,28 @@ def get_canonical_form_slice(
in which 0 <= start <= stop <= length and step > 0, and a flag which says in which 0 <= start <= stop <= length and step > 0, and a flag which says
if the resulting set of numbers needs to be reversed or not. if the resulting set of numbers needs to be reversed or not.
Given a scalar index `idx` that may or not be negative, convert it to
a certainly positive form `idx if idx >= 0 else length + idx`.
Returns
-------
slc
Canonical form slice or scalar variable.
direction
Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
""" """
from pytensor.tensor import ge, lt, sign, switch from pytensor.tensor import ge, lt, sign, switch
# Other non-slice types are the scalar indexing case
if not isinstance(theslice, slice): if not isinstance(theslice, slice):
try: if isinstance(theslice, int | np.integer | ScalarVariable) or (
value = as_index_literal(theslice) isinstance(theslice, TensorVariable) and theslice.ndim == 0
except NotScalarConstantError: ):
value = theslice cano = switch(lt(theslice, 0), (theslice + length), theslice)
return scalar_from_tensor(cano), 1
value = switch(lt(value, 0), (value + length), value) raise ValueError(f"Slice {theslice} is not a supported slice type.")
return value, 1 # At this point we have a slice object. Possibly with symbolic inputs.
def analyze(x): def analyze(x):
try: try:
...@@ -243,6 +325,7 @@ def get_canonical_form_slice( ...@@ -243,6 +325,7 @@ def get_canonical_form_slice(
and is_step_constant and is_step_constant
and is_length_constant and is_length_constant
): ):
assert isinstance(length, int)
_start, _stop, _step = slice(start, stop, step).indices(length) _start, _stop, _step = slice(start, stop, step).indices(length)
if _start <= _stop and _step >= 1: if _start <= _stop and _step >= 1:
return slice(_start, _stop, _step), 1 return slice(_start, _stop, _step), 1
...@@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"): ...@@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"):
return a[full_indices] return a[full_indices]
@_get_vector_length.register(Subtensor) @_get_vector_length.register(Subtensor) # type: ignore
def _get_vector_length_Subtensor(op, var): def _get_vector_length_Subtensor(op, var):
# If we take a slice, we know how many elements it will result in # If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases. # TODO: We can cover more `*Subtensor` cases.
......
...@@ -25,7 +25,6 @@ pytensor/tensor/random/op.py ...@@ -25,7 +25,6 @@ pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py pytensor/tensor/rewriting/basic.py
pytensor/tensor/slinalg.py pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py
pytensor/tensor/type.py pytensor/tensor/type.py
pytensor/tensor/type_other.py pytensor/tensor/type_other.py
pytensor/tensor/variable.py pytensor/tensor/variable.py
...@@ -16,8 +16,8 @@ from pytensor.configdefaults import config ...@@ -16,8 +16,8 @@ from pytensor.configdefaults import config
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar from pytensor.scalar.basic import as_scalar, int16
from pytensor.tensor import get_vector_length, vectorize from pytensor.tensor import as_tensor, get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf from pytensor.tensor.math import exp, isinf
...@@ -69,7 +69,13 @@ from pytensor.tensor.type import ( ...@@ -69,7 +69,13 @@ from pytensor.tensor.type import (
tensor5, tensor5,
vector, vector,
) )
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype from pytensor.tensor.type_other import (
NoneConst,
SliceConstant,
as_symbolic_slice,
make_slice,
slicetype,
)
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.utils import inplace_func, integers_ranged, random from tests.tensor.utils import inplace_func, integers_ranged, random
...@@ -106,11 +112,51 @@ def test_as_index_literal(): ...@@ -106,11 +112,51 @@ def test_as_index_literal():
class TestGetCanonicalFormSlice: class TestGetCanonicalFormSlice:
@pytest.mark.parametrize(
"idx",
[
NoneConst,
None,
as_symbolic_slice(slice(3, 7, 2)),
as_symbolic_slice(slice(3, int16(), 2)),
vector(),
],
)
def test_unsupported_inputs(self, idx):
with pytest.raises(ValueError, match="not a supported slice"):
get_canonical_form_slice(idx, 5)
def test_scalar_constant(self): def test_scalar_constant(self):
a = as_scalar(0) a = as_scalar(0)
length = lscalar() length = lscalar()
res = get_canonical_form_slice(a, length) res = get_canonical_form_slice(a, length)
assert res[0].owner.op == ptb.switch assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1
def test_tensor_constant(self):
a = as_tensor(0)
length = lscalar()
res = get_canonical_form_slice(a, length)
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1
def test_symbolic_scalar(self):
a = int16()
length = lscalar()
res = get_canonical_form_slice(a, length)
assert res[0].owner.op, ptb.switch
assert res[1] == 1
def test_symbolic_tensor(self):
a = lscalar()
length = lscalar()
res = get_canonical_form_slice(a, length)
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1
def test_all_integer(self):
res = get_canonical_form_slice(slice(1, 5, 2), 7)
assert isinstance(res[0], slice)
assert res[1] == 1 assert res[1] == 1
def test_all_symbolic(self): def test_all_symbolic(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论