提交 1a31bb30 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow ellipsis in specify_shape helper

上级 499c3d41
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
from collections.abc import Sequence from collections.abc import Sequence
from numbers import Number from numbers import Number
from textwrap import dedent from textwrap import dedent
from types import EllipsisType
from typing import TYPE_CHECKING, Union, cast from typing import TYPE_CHECKING, Union, cast
from typing import cast as typing_cast from typing import cast as typing_cast
...@@ -27,7 +28,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable ...@@ -27,7 +28,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable
if TYPE_CHECKING: if TYPE_CHECKING:
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
ShapeValueType = None | np.integer | int | Variable ShapeValueType = None | EllipsisType | np.integer | int | Variable
def register_shape_c_code(type, code, version=()): def register_shape_c_code(type, code, version=()):
...@@ -549,26 +550,37 @@ def specify_shape( ...@@ -549,26 +550,37 @@ def specify_shape(
If a dimension's shape value is ``None``, the size of that dimension is not If a dimension's shape value is ``None``, the size of that dimension is not
considered fixed/static at runtime. considered fixed/static at runtime.
A single ``Ellipsis`` can be used to imply multiple ``None`` specified dimensions
""" """
x = as_tensor_variable(x) # type: ignore[arg-type]
if not isinstance(shape, tuple | list): if not isinstance(shape, tuple | list):
shape = (shape,) shape = (shape,)
# If shape is a symbolic 1d vector of fixed length, we separate the items into a # If shape is a symbolic 1d vector of fixed length, we separate the items into a
# tuple with one entry per shape dimension # tuple with one entry per shape dimension
if len(shape) == 1 and shape[0] is not None: if len(shape) == 1 and shape[0] not in (None, Ellipsis):
shape_vector = ptb.as_tensor_variable(shape[0]) shape_vector = ptb.as_tensor_variable(shape[0]) # type: ignore[arg-type]
if shape_vector.ndim == 1: if shape_vector.ndim == 1:
try: try:
shape = tuple(shape_vector) shape = tuple(shape_vector)
except ValueError: except ValueError:
raise ValueError("Shape vector must have fixed dimensions") raise ValueError("Shape vector must have fixed dimensions")
if Ellipsis in shape:
ellipsis_pos = shape.index(Ellipsis)
implied_none = x.type.ndim - (len(shape) - 1)
shape = (
*shape[:ellipsis_pos],
*((None,) * implied_none),
*shape[ellipsis_pos + 1 :],
)
if Ellipsis in shape[ellipsis_pos + 1 :]:
raise ValueError("Multiple Ellipsis in specify_shape")
# If the specified shape is already encoded in the input static shape, do nothing # If the specified shape is already encoded in the input static shape, do nothing
# This ignores PyTensor constants in shape # This ignores PyTensor constants in shape
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12.
new_shape_info = any( new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
) )
......
...@@ -480,6 +480,33 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -480,6 +480,33 @@ class TestSpecifyShape(utt.InferShapeTester):
y = specify_shape(x, (None, 5)) y = specify_shape(x, (None, 5))
assert y.type.shape == (3, 5) assert y.type.shape == (3, 5)
def test_ellipsis(self):
x = tensor("x", shape=(None, None, None, None))
y = specify_shape(x, ...)
assert y.type.shape == (None, None, None, None)
y = specify_shape(x, (...,))
assert y.type.shape == (None, None, None, None)
y = specify_shape(x, (..., 5))
assert y.type.shape == (None, None, None, 5)
y = specify_shape(x, (5, ...))
assert y.type.shape == (5, None, None, None)
y = specify_shape(x, (5, ..., 3))
assert y.type.shape == (5, None, None, 3)
y = specify_shape(x, (5, ..., 3, None))
assert y.type.shape == (5, None, 3, None)
y = specify_shape(x, (5, 1, ..., 3, None))
assert y.type.shape == (5, 1, 3, None)
with pytest.raises(ValueError, match="Multiple Ellipsis in specify_shape"):
specify_shape(x, (..., None, ...))
def test_python_perform(self): def test_python_perform(self):
"""Test the Python `Op.perform` implementation.""" """Test the Python `Op.perform` implementation."""
x = scalar() x = scalar()
...@@ -583,6 +610,8 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -583,6 +610,8 @@ class TestSpecifyShape(utt.InferShapeTester):
assert specify_shape(x, (1, 2, None)) is x assert specify_shape(x, (1, 2, None)) is x
assert specify_shape(x, (None, None, None)) is x assert specify_shape(x, (None, None, None)) is x
assert specify_shape(x, (...,)) is x
assert specify_shape(x, (..., None)) is x
assert specify_shape(x, (1, 2, 3)) is not x assert specify_shape(x, (1, 2, 3)) is not x
assert specify_shape(x, (None, None, 3)) is not x assert specify_shape(x, (None, None, 3)) is not x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论