提交 c34f37f1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Reintroduce deprecated broadcastable keyword arguments

This should retain most backward compatibility.
上级 652e6e4a
import traceback
import warnings
import numpy as np
......@@ -38,6 +39,7 @@ def tensor_constructor(
borrow=False,
shape=None,
target="cpu",
broadcastable=None,
):
"""
SharedVariable Constructor for TensorType.
......@@ -49,6 +51,13 @@ def tensor_constructor(
optional `shape` argument will override this default.
"""
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
if target != "cpu":
raise TypeError("not for cpu")
......
import logging
import warnings
from typing import Iterable, Optional, Union
import numpy as np
......@@ -60,8 +61,9 @@ class TensorType(CType):
def __init__(
self,
dtype: Union[str, np.dtype],
shape: Iterable[Optional[Union[bool, int]]],
shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
name: Optional[str] = None,
broadcastable: Optional[Iterable[bool]] = None,
):
r"""
......@@ -79,6 +81,14 @@ class TensorType(CType):
Optional name for this type.
"""
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
if isinstance(dtype, str) and dtype == "floatX":
self.dtype = config.floatX
else:
......@@ -95,7 +105,13 @@ class TensorType(CType):
self.name = name
self.numpy_dtype = np.dtype(self.dtype)
def clone(self, dtype=None, shape=None, **kwargs):
def clone(self, dtype=None, shape=None, broadcastable=None, **kwargs):
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
if dtype is None:
dtype = self.dtype
if shape is None:
......
......@@ -685,3 +685,10 @@ def test_scalar_shared_options():
def test_get_vector_length():
x = aesara.shared(np.array((2, 3, 4, 5)))
assert get_vector_length(x) == 4
def test_deprecated_kwargs():
with pytest.warns(DeprecationWarning, match=".*broadcastable.*"):
res = aesara.shared(np.array([[1.0]]), broadcastable=(True, False))
assert res.type.shape == (1, None)
......@@ -286,3 +286,15 @@ def test_fixed_shape_convert_variable():
res = t3.convert_variable(t4_var)
assert res.type == t4
assert res.type.shape == (3, 2)
def test_deprecated_kwargs():
with pytest.warns(DeprecationWarning, match=".*broadcastable.*"):
res = TensorType("float64", broadcastable=(True, False))
assert res.shape == (1, None)
with pytest.warns(DeprecationWarning, match=".*broadcastable.*"):
new_res = res.clone(broadcastable=(False, True))
assert new_res.shape == (None, 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论