提交 c34f37f1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Reintroduce deprecated broadcastable keyword arguments

This should retain most backward compatibility.
上级 652e6e4a
import traceback import traceback
import warnings
import numpy as np import numpy as np
...@@ -38,6 +39,7 @@ def tensor_constructor( ...@@ -38,6 +39,7 @@ def tensor_constructor(
borrow=False, borrow=False,
shape=None, shape=None,
target="cpu", target="cpu",
broadcastable=None,
): ):
""" """
SharedVariable Constructor for TensorType. SharedVariable Constructor for TensorType.
...@@ -49,6 +51,13 @@ def tensor_constructor( ...@@ -49,6 +51,13 @@ def tensor_constructor(
optional `shape` argument will override this default. optional `shape` argument will override this default.
""" """
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
if target != "cpu": if target != "cpu":
raise TypeError("not for cpu") raise TypeError("not for cpu")
......
import logging import logging
import warnings
from typing import Iterable, Optional, Union from typing import Iterable, Optional, Union
import numpy as np import numpy as np
...@@ -60,8 +61,9 @@ class TensorType(CType): ...@@ -60,8 +61,9 @@ class TensorType(CType):
def __init__( def __init__(
self, self,
dtype: Union[str, np.dtype], dtype: Union[str, np.dtype],
shape: Iterable[Optional[Union[bool, int]]], shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
name: Optional[str] = None, name: Optional[str] = None,
broadcastable: Optional[Iterable[bool]] = None,
): ):
r""" r"""
...@@ -79,6 +81,14 @@ class TensorType(CType): ...@@ -79,6 +81,14 @@ class TensorType(CType):
Optional name for this type. Optional name for this type.
""" """
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
if isinstance(dtype, str) and dtype == "floatX": if isinstance(dtype, str) and dtype == "floatX":
self.dtype = config.floatX self.dtype = config.floatX
else: else:
...@@ -95,7 +105,13 @@ class TensorType(CType): ...@@ -95,7 +105,13 @@ class TensorType(CType):
self.name = name self.name = name
self.numpy_dtype = np.dtype(self.dtype) self.numpy_dtype = np.dtype(self.dtype)
def clone(self, dtype=None, shape=None, **kwargs): def clone(self, dtype=None, shape=None, broadcastable=None, **kwargs):
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
if dtype is None: if dtype is None:
dtype = self.dtype dtype = self.dtype
if shape is None: if shape is None:
......
...@@ -685,3 +685,10 @@ def test_scalar_shared_options(): ...@@ -685,3 +685,10 @@ def test_scalar_shared_options():
def test_get_vector_length(): def test_get_vector_length():
x = aesara.shared(np.array((2, 3, 4, 5))) x = aesara.shared(np.array((2, 3, 4, 5)))
assert get_vector_length(x) == 4 assert get_vector_length(x) == 4
def test_deprecated_kwargs():
with pytest.warns(DeprecationWarning, match=".*broadcastable.*"):
res = aesara.shared(np.array([[1.0]]), broadcastable=(True, False))
assert res.type.shape == (1, None)
...@@ -286,3 +286,15 @@ def test_fixed_shape_convert_variable(): ...@@ -286,3 +286,15 @@ def test_fixed_shape_convert_variable():
res = t3.convert_variable(t4_var) res = t3.convert_variable(t4_var)
assert res.type == t4 assert res.type == t4
assert res.type.shape == (3, 2) assert res.type.shape == (3, 2)
def test_deprecated_kwargs():
with pytest.warns(DeprecationWarning, match=".*broadcastable.*"):
res = TensorType("float64", broadcastable=(True, False))
assert res.shape == (1, None)
with pytest.warns(DeprecationWarning, match=".*broadcastable.*"):
new_res = res.clone(broadcastable=(False, True))
assert new_res.shape == (None, 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论