提交 1a31bb30 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow ellipsis in specify_shape helper

上级 499c3d41
......@@ -2,6 +2,7 @@ import warnings
from collections.abc import Sequence
from numbers import Number
from textwrap import dedent
from types import EllipsisType
from typing import TYPE_CHECKING, Union, cast
from typing import cast as typing_cast
......@@ -27,7 +28,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable
if TYPE_CHECKING:
from pytensor.tensor import TensorLike
ShapeValueType = None | np.integer | int | Variable
ShapeValueType = None | EllipsisType | np.integer | int | Variable
def register_shape_c_code(type, code, version=()):
......@@ -549,26 +550,37 @@ def specify_shape(
If a dimension's shape value is ``None``, the size of that dimension is not
considered fixed/static at runtime.
A single ``Ellipsis`` can be used to imply multiple ``None`` specified dimensions
"""
x = as_tensor_variable(x) # type: ignore[arg-type]
if not isinstance(shape, tuple | list):
shape = (shape,)
# If shape is a symbolic 1d vector of fixed length, we separate the items into a
# tuple with one entry per shape dimension
if len(shape) == 1 and shape[0] is not None:
shape_vector = ptb.as_tensor_variable(shape[0])
if len(shape) == 1 and shape[0] not in (None, Ellipsis):
shape_vector = ptb.as_tensor_variable(shape[0]) # type: ignore[arg-type]
if shape_vector.ndim == 1:
try:
shape = tuple(shape_vector)
except ValueError:
raise ValueError("Shape vector must have fixed dimensions")
if Ellipsis in shape:
ellipsis_pos = shape.index(Ellipsis)
implied_none = x.type.ndim - (len(shape) - 1)
shape = (
*shape[:ellipsis_pos],
*((None,) * implied_none),
*shape[ellipsis_pos + 1 :],
)
if Ellipsis in shape[ellipsis_pos + 1 :]:
raise ValueError("Multiple Ellipsis in specify_shape")
# If the specified shape is already encoded in the input static shape, do nothing
# This ignores PyTensor constants in shape
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12.
new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
)
......
......@@ -480,6 +480,33 @@ class TestSpecifyShape(utt.InferShapeTester):
y = specify_shape(x, (None, 5))
assert y.type.shape == (3, 5)
def test_ellipsis(self):
x = tensor("x", shape=(None, None, None, None))
y = specify_shape(x, ...)
assert y.type.shape == (None, None, None, None)
y = specify_shape(x, (...,))
assert y.type.shape == (None, None, None, None)
y = specify_shape(x, (..., 5))
assert y.type.shape == (None, None, None, 5)
y = specify_shape(x, (5, ...))
assert y.type.shape == (5, None, None, None)
y = specify_shape(x, (5, ..., 3))
assert y.type.shape == (5, None, None, 3)
y = specify_shape(x, (5, ..., 3, None))
assert y.type.shape == (5, None, 3, None)
y = specify_shape(x, (5, 1, ..., 3, None))
assert y.type.shape == (5, 1, 3, None)
with pytest.raises(ValueError, match="Multiple Ellipsis in specify_shape"):
specify_shape(x, (..., None, ...))
def test_python_perform(self):
"""Test the Python `Op.perform` implementation."""
x = scalar()
......@@ -583,6 +610,8 @@ class TestSpecifyShape(utt.InferShapeTester):
assert specify_shape(x, (1, 2, None)) is x
assert specify_shape(x, (None, None, None)) is x
assert specify_shape(x, (...,)) is x
assert specify_shape(x, (..., None)) is x
assert specify_shape(x, (1, 2, 3)) is not x
assert specify_shape(x, (None, None, 3)) is not x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论