提交 aa98984b authored 作者: notoraptor's avatar notoraptor

Print GpuArrayType like TensorType

(e.g. with "matrix" instead of "(False, False)").
上级 b551f3b1
......@@ -209,8 +209,23 @@ class GpuArrayType(Type):
return get_context(self.context_name)
def __repr__(self):
return "GpuArrayType<%s>(%s, %s)" % (self.context_name, self.dtype,
self.broadcastable)
# Inspired from TensorType.
if self.name:
return self.name
else:
b = self.broadcastable
named_broadcastable = {tuple(): 'scalar',
(False,): 'vector',
(False, True): 'col',
(True, False): 'row',
(False, False): 'matrix'}
if b in named_broadcastable:
bcast = named_broadcastable[b]
elif any(b):
bcast = str(b)
else:
bcast = '%iD' % len(b)
return "GpuArrayType<%s>(%s, %s)" % (self.context_name, self.dtype, bcast)
def filter(self, data, strict=False, allow_downcast=None):
return self.filter_inplace(data, None, strict=strict,
......
......@@ -373,7 +373,6 @@ class TensorType(Type):
def __repr__(self):
return str(self)
# "TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub, check_input=True):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论