提交 5cfd9da9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate ScalarSharedVariable

上级 1091b441
......@@ -9,6 +9,18 @@ from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import _tensor_py_operators
def __getattr__(name):
if name == "ScalarSharedVariable":
warnings.warn(
"The class `ScalarSharedVariable` has been deprecated. "
"Use `TensorSharedVariable` instead and check for `ndim==0`.",
FutureWarning,
)
return TensorSharedVariable
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def load_shared_variable(val):
"""
This function is only here to keep some pickles loading
......@@ -94,10 +106,6 @@ def tensor_constructor(
)
class ScalarSharedVariable(TensorSharedVariable):
pass
@shared_constructor.register(np.number)
@shared_constructor.register(float)
@shared_constructor.register(int)
......@@ -132,7 +140,7 @@ def scalar_constructor(
# Do not pass the dtype to asarray because we want this to fail if
# strict is True and the types do not match.
rval = ScalarSharedVariable(
rval = TensorSharedVariable(
type=tensor_type,
value=np.array(value, copy=True),
name=name,
......
......@@ -10,7 +10,7 @@ from pytensor.misc.may_share_memory import may_share_memory
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.shape import Shape_i, specify_shape
from pytensor.tensor.sharedvar import ScalarSharedVariable, TensorSharedVariable
from pytensor.tensor.sharedvar import TensorSharedVariable
from tests import unittest_tools as utt
......@@ -679,12 +679,17 @@ def test_tensor_shared_zero():
def test_scalar_shared_options():
res = pytensor.shared(value=np.float32(0.0), name="lk", borrow=True)
assert isinstance(res, ScalarSharedVariable)
assert isinstance(res, TensorSharedVariable) and res.type.ndim == 0
assert res.type.dtype == "float32"
assert res.name == "lk"
assert res.type.shape == ()
def test_scalar_shared_deprecated():
with pytest.warns(FutureWarning, match=".*deprecated.*"):
pytensor.tensor.sharedvar.ScalarSharedVariable
def test_get_vector_length():
x = pytensor.shared(np.array((2, 3, 4, 5)))
assert get_vector_length(x) == 4
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论