Commit 10b84747, authored by Ricardo Vieira, committed by Ricardo Vieira

Small tweaks to XTensorType

Parent commit: 0bb15f9d
...@@ -71,6 +71,8 @@ class XTensorType(Type, HasDataType, HasShape): ...@@ -71,6 +71,8 @@ class XTensorType(Type, HasDataType, HasShape):
self.name = name self.name = name
self.numpy_dtype = np.dtype(self.dtype) self.numpy_dtype = np.dtype(self.dtype)
self.filter_checks_isfinite = False self.filter_checks_isfinite = False
# broadcastable is here just for code that would work fine with XTensorType but checks for it
self.broadcastable = (False,) * self.ndim
def clone( def clone(
self, self,
...@@ -93,6 +95,10 @@ class XTensorType(Type, HasDataType, HasShape): ...@@ -93,6 +95,10 @@ class XTensorType(Type, HasDataType, HasShape):
self, value, strict=strict, allow_downcast=allow_downcast self, value, strict=strict, allow_downcast=allow_downcast
) )
@staticmethod
def may_share_memory(a, b):
return TensorType.may_share_memory(a, b)
def filter_variable(self, other, allow_convert=True): def filter_variable(self, other, allow_convert=True):
if not isinstance(other, Variable): if not isinstance(other, Variable):
# The value is not a Variable: we cast it into # The value is not a Variable: we cast it into
...@@ -160,7 +166,7 @@ class XTensorType(Type, HasDataType, HasShape): ...@@ -160,7 +166,7 @@ class XTensorType(Type, HasDataType, HasShape):
return None return None
def __repr__(self): def __repr__(self):
return f"XTensorType({self.dtype}, {self.dims}, {self.shape})" return f"XTensorType({self.dtype}, shape={self.shape}, dims={self.dims})"
def __hash__(self): def __hash__(self):
return hash((type(self), self.dtype, self.shape, self.dims)) return hash((type(self), self.dtype, self.shape, self.dims))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论