提交 84a540f1 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Maxim Kochurov

remove deprecated stack interface

上级 f67638b6
......@@ -7,12 +7,9 @@ manipulation of tensors.
import builtins
import warnings
from collections.abc import Sequence
from functools import partial
from numbers import Number
from typing import TYPE_CHECKING, Optional
from typing import Sequence as TypeSequence
from typing import Tuple, Union
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
from typing import cast as type_cast
import numpy as np
......@@ -1337,8 +1334,8 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
def infer_static_shape(
shape: Union[Variable, TypeSequence[Union[Variable, int]]]
) -> Tuple[TypeSequence["TensorLike"], TypeSequence[Optional[int]]]:
shape: Union[Variable, Sequence[Union[Variable, int]]]
) -> Tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
"""Infer the static shapes implied by the potentially symbolic elements in `shape`.
`shape` will be validated and constant folded. As a result, this function
......@@ -2538,19 +2535,16 @@ def roll(x, shift, axis=None):
)
def stack(*tensors, **kwargs):
def stack(tensors: Sequence[TensorVariable], axis: int = 0):
"""Stack tensors in sequence on given axis (default is 0).
Take a sequence of tensors and stack them on given axis to make a single
tensor. The size in dimension `axis` of the result will be equal to the number
of tensors passed.
Note: The interface stack(*tensors) is deprecated, you should use
stack(tensors, axis=0) instead.
Parameters
----------
tensors : list or tuple of tensors
tensors : Sequence[TensorVariable]
A list of tensors to be stacked.
axis : int
The index of the new axis. Default value is 0.
......@@ -2585,35 +2579,9 @@ def stack(*tensors, **kwargs):
>>> rval.shape # 3 tensors are stacked on axis -2
(2, 2, 2, 3, 2)
"""
# ---> Remove this when moving to the new interface:
if not tensors and not kwargs:
raise ValueError("No tensor arguments provided")
if not kwargs and not isinstance(tensors[0], (list, tuple)):
warnings.warn(
"stack(*tensors) interface is deprecated, use"
" stack(tensors, axis=0) instead.",
DeprecationWarning,
stacklevel=3,
)
axis = 0
elif "tensors" in kwargs:
tensors = kwargs["tensors"]
if "axis" in kwargs:
axis = kwargs["axis"]
else:
axis = 0
else:
if len(tensors) == 2:
axis = tensors[1]
elif "axis" in kwargs:
axis = kwargs["axis"]
else:
axis = 0
tensors = tensors[0]
# <--- Until here.
if len(tensors) == 0:
if not isinstance(tensors, Sequence):
raise TypeError("First argument should be Sequence[TensorVariable]")
elif len(tensors) == 0:
raise ValueError("No tensor arguments provided")
# If all tensors are scalars of the same type, call make_vector.
......@@ -3662,8 +3630,8 @@ def swapaxes(y, axis1, axis2):
def moveaxis(
a: Union[np.ndarray, TensorVariable],
source: Union[int, TypeSequence[int]],
destination: Union[int, TypeSequence[int]],
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
) -> TensorVariable:
"""Move axes of a TensorVariable to new positions.
......
......@@ -1332,8 +1332,10 @@ class TestJoinAndSplit:
with pytest.raises(IndexError):
stack([a, b], -4)
# Testing depreciation warning
with pytest.warns(DeprecationWarning):
# Testing depreciation warning is now an informative error
with pytest.raises(
TypeError, match=r"First argument should be Sequence\[TensorVariable\]"
):
s = stack(a, b)
def test_stack_hessian(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论