提交 84a540f1 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Maxim Kochurov

remove deprecated stack interface

上级 f67638b6
...@@ -7,12 +7,9 @@ manipulation of tensors. ...@@ -7,12 +7,9 @@ manipulation of tensors.
import builtins import builtins
import warnings import warnings
from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
from typing import Sequence as TypeSequence
from typing import Tuple, Union
from typing import cast as type_cast from typing import cast as type_cast
import numpy as np import numpy as np
...@@ -1337,8 +1334,8 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None): ...@@ -1337,8 +1334,8 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
def infer_static_shape( def infer_static_shape(
shape: Union[Variable, TypeSequence[Union[Variable, int]]] shape: Union[Variable, Sequence[Union[Variable, int]]]
) -> Tuple[TypeSequence["TensorLike"], TypeSequence[Optional[int]]]: ) -> Tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
"""Infer the static shapes implied by the potentially symbolic elements in `shape`. """Infer the static shapes implied by the potentially symbolic elements in `shape`.
`shape` will be validated and constant folded. As a result, this function `shape` will be validated and constant folded. As a result, this function
...@@ -2538,19 +2535,16 @@ def roll(x, shift, axis=None): ...@@ -2538,19 +2535,16 @@ def roll(x, shift, axis=None):
) )
def stack(*tensors, **kwargs): def stack(tensors: Sequence[TensorVariable], axis: int = 0):
"""Stack tensors in sequence on given axis (default is 0). """Stack tensors in sequence on given axis (default is 0).
Take a sequence of tensors and stack them on given axis to make a single Take a sequence of tensors and stack them on given axis to make a single
tensor. The size in dimension `axis` of the result will be equal to the number tensor. The size in dimension `axis` of the result will be equal to the number
of tensors passed. of tensors passed.
Note: The interface stack(*tensors) is deprecated, you should use
stack(tensors, axis=0) instead.
Parameters Parameters
---------- ----------
tensors : list or tuple of tensors tensors : Sequence[TensorVariable]
A list of tensors to be stacked. A list of tensors to be stacked.
axis : int axis : int
The index of the new axis. Default value is 0. The index of the new axis. Default value is 0.
...@@ -2585,35 +2579,9 @@ def stack(*tensors, **kwargs): ...@@ -2585,35 +2579,9 @@ def stack(*tensors, **kwargs):
>>> rval.shape # 3 tensors are stacked on axis -2 >>> rval.shape # 3 tensors are stacked on axis -2
(2, 2, 2, 3, 2) (2, 2, 2, 3, 2)
""" """
# ---> Remove this when moving to the new interface: if not isinstance(tensors, Sequence):
if not tensors and not kwargs: raise TypeError("First argument should be Sequence[TensorVariable]")
raise ValueError("No tensor arguments provided") elif len(tensors) == 0:
if not kwargs and not isinstance(tensors[0], (list, tuple)):
warnings.warn(
"stack(*tensors) interface is deprecated, use"
" stack(tensors, axis=0) instead.",
DeprecationWarning,
stacklevel=3,
)
axis = 0
elif "tensors" in kwargs:
tensors = kwargs["tensors"]
if "axis" in kwargs:
axis = kwargs["axis"]
else:
axis = 0
else:
if len(tensors) == 2:
axis = tensors[1]
elif "axis" in kwargs:
axis = kwargs["axis"]
else:
axis = 0
tensors = tensors[0]
# <--- Until here.
if len(tensors) == 0:
raise ValueError("No tensor arguments provided") raise ValueError("No tensor arguments provided")
# If all tensors are scalars of the same type, call make_vector. # If all tensors are scalars of the same type, call make_vector.
...@@ -3662,8 +3630,8 @@ def swapaxes(y, axis1, axis2): ...@@ -3662,8 +3630,8 @@ def swapaxes(y, axis1, axis2):
def moveaxis( def moveaxis(
a: Union[np.ndarray, TensorVariable], a: Union[np.ndarray, TensorVariable],
source: Union[int, TypeSequence[int]], source: Union[int, Sequence[int]],
destination: Union[int, TypeSequence[int]], destination: Union[int, Sequence[int]],
) -> TensorVariable: ) -> TensorVariable:
"""Move axes of a TensorVariable to new positions. """Move axes of a TensorVariable to new positions.
......
...@@ -1332,8 +1332,10 @@ class TestJoinAndSplit: ...@@ -1332,8 +1332,10 @@ class TestJoinAndSplit:
with pytest.raises(IndexError): with pytest.raises(IndexError):
stack([a, b], -4) stack([a, b], -4)
# Testing depreciation warning # Testing depreciation warning is now an informative error
with pytest.warns(DeprecationWarning): with pytest.raises(
TypeError, match=r"First argument should be Sequence\[TensorVariable\]"
):
s = stack(a, b) s = stack(a, b)
def test_stack_hessian(self): def test_stack_hessian(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论