提交 371048ea authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make sure `Type` subclasses implement the clone method

上级 b920bd29
......@@ -212,6 +212,10 @@ class Type(MetaObject):
"""
return self.Constant(type=self, data=value, name=name)
def clone(self, *args, **kwargs):
"""Clone a copy of this type with the given arguments/keyword values, if any."""
return type(self)(*args, **kwargs)
def __call__(self, name: Optional[Text] = None) -> Variable:
"""Return a new `Variable` instance of Type `self`.
......
......@@ -345,6 +345,11 @@ class Scalar(CType):
self.dtype = dtype
self.dtype_specs() # error checking
def clone(self, dtype=None, **kwargs):
if dtype is None:
dtype = self.dtype
return type(self)(dtype)
@staticmethod
def may_share_memory(a, b):
# This class represent basic c type, represented in python
......
......@@ -89,6 +89,13 @@ class SparseType(Type):
list(self.format_cls.keys()),
)
def clone(self, format=None, dtype=None, **kwargs):
if format is None:
format = self.format
if dtype is None:
dtype = self.dtype
return type(self)(format, dtype)
def filter(self, value, strict=False, allow_downcast=None):
if (
isinstance(value, self.format_cls[self.format])
......
......@@ -52,6 +52,9 @@ make_slice = MakeSlice()
class SliceType(Type):
def clone(self, **kwargs):
return type(self)()
def filter(self, x, strict=False, allow_downcast=None):
if isinstance(x, slice):
return x
......
......@@ -67,6 +67,11 @@ def test_convert_variable():
t1.convert_variable(v3)
def test_default_clone():
mt = MyType(1)
assert isinstance(mt.clone(1), MyType)
@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
......
......@@ -60,3 +60,9 @@ def test_filter_float_subclass():
with config.change_flags(floatX="float64"):
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.floating)
def test_clone():
st = Scalar("int64")
assert st == st.clone()
assert st.clone("float64").dtype == "float64"
def test_sparse_type():
import aesara.sparse
from aesara.sparse.type import SparseType
# They need to be available even if scipy is not available.
assert hasattr(aesara.sparse, "SparseType")
def test_clone():
st = SparseType("csr", "float64")
assert st == st.clone()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论