提交 371048ea authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make sure `Type` subclasses implement the clone method

上级 b920bd29
...@@ -212,6 +212,10 @@ class Type(MetaObject): ...@@ -212,6 +212,10 @@ class Type(MetaObject):
""" """
return self.Constant(type=self, data=value, name=name) return self.Constant(type=self, data=value, name=name)
def clone(self, *args, **kwargs):
"""Clone a copy of this type with the given arguments/keyword values, if any."""
return type(self)(*args, **kwargs)
def __call__(self, name: Optional[Text] = None) -> Variable: def __call__(self, name: Optional[Text] = None) -> Variable:
"""Return a new `Variable` instance of Type `self`. """Return a new `Variable` instance of Type `self`.
......
...@@ -345,6 +345,11 @@ class Scalar(CType): ...@@ -345,6 +345,11 @@ class Scalar(CType):
self.dtype = dtype self.dtype = dtype
self.dtype_specs() # error checking self.dtype_specs() # error checking
def clone(self, dtype=None, **kwargs):
if dtype is None:
dtype = self.dtype
return type(self)(dtype)
@staticmethod @staticmethod
def may_share_memory(a, b): def may_share_memory(a, b):
# This class represent basic c type, represented in python # This class represent basic c type, represented in python
......
...@@ -89,6 +89,13 @@ class SparseType(Type): ...@@ -89,6 +89,13 @@ class SparseType(Type):
list(self.format_cls.keys()), list(self.format_cls.keys()),
) )
def clone(self, format=None, dtype=None, **kwargs):
if format is None:
format = self.format
if dtype is None:
dtype = self.dtype
return type(self)(format, dtype)
def filter(self, value, strict=False, allow_downcast=None): def filter(self, value, strict=False, allow_downcast=None):
if ( if (
isinstance(value, self.format_cls[self.format]) isinstance(value, self.format_cls[self.format])
......
...@@ -52,6 +52,9 @@ make_slice = MakeSlice() ...@@ -52,6 +52,9 @@ make_slice = MakeSlice()
class SliceType(Type): class SliceType(Type):
def clone(self, **kwargs):
return type(self)()
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
if isinstance(x, slice): if isinstance(x, slice):
return x return x
......
...@@ -67,6 +67,11 @@ def test_convert_variable(): ...@@ -67,6 +67,11 @@ def test_convert_variable():
t1.convert_variable(v3) t1.convert_variable(v3)
def test_default_clone():
mt = MyType(1)
assert isinstance(mt.clone(1), MyType)
@pytest.mark.skipif( @pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test." not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
) )
......
...@@ -60,3 +60,9 @@ def test_filter_float_subclass(): ...@@ -60,3 +60,9 @@ def test_filter_float_subclass():
with config.change_flags(floatX="float64"): with config.change_flags(floatX="float64"):
filtered_nan = test_type.filter(nan) filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.floating) assert isinstance(filtered_nan, np.floating)
def test_clone():
st = Scalar("int64")
assert st == st.clone()
assert st.clone("float64").dtype == "float64"
def test_sparse_type(): from aesara.sparse.type import SparseType
import aesara.sparse
# They need to be available even if scipy is not available.
assert hasattr(aesara.sparse, "SparseType") def test_clone():
st = SparseType("csr", "float64")
assert st == st.clone()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论