提交 66934167 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Specialize TensorType string representation

上级 88516d2e
......@@ -386,6 +386,8 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
if self.name:
return self.name
else:
shape = self.shape
len_shape = len(shape)
def shape_str(s):
if s is None:
......@@ -393,14 +395,18 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
else:
return str(s)
formatted_shape = ", ".join([shape_str(s) for s in self.shape])
if len(self.shape) == 1:
formatted_shape = ", ".join([shape_str(s) for s in shape])
if len_shape == 1:
formatted_shape += ","
return f"TensorType({self.dtype}, ({formatted_shape}))"
if len_shape > 2:
name = f"Tensor{len_shape}"
else:
name = ("Scalar", "Vector", "Matrix")[len_shape]
return f"{name}({self.dtype}, shape=({formatted_shape}))"
def __repr__(self):
return str(self)
return f"TensorType({self.dtype}, shape={self.shape})"
@staticmethod
def may_share_memory(a, b):
......
......@@ -1030,7 +1030,7 @@ class TensorConstant(TensorVariable, Constant[_TensorTypeType]):
else:
val = f"{self.data}"
if len(val) > 20:
val = val[:10] + ".." + val[-10:]
val = val[:10].strip() + " ... " + val[-10:].strip()
if self.name is not None:
name = self.name
......
......@@ -580,10 +580,10 @@ Inner graphs:
OpFromGraph{inline=False} [id A]
← Add [id E]
├─ *0-<TensorType(float64, (?, ?))> [id F]
├─ *0-<Matrix(float64, shape=(?, ?))> [id F]
└─ Mul [id G]
├─ *1-<TensorType(float64, (?, ?))> [id H]
└─ *2-<TensorType(float64, (?, ?))> [id I]
├─ *1-<Matrix(float64, shape=(?, ?))> [id H]
└─ *2-<Matrix(float64, shape=(?, ?))> [id I]
"""
for truth, out in zip(exp_res.split("\n"), lines):
......
......@@ -252,7 +252,7 @@ def test_fixed_shape_basic():
assert t1.shape == (2, 3)
assert t1.broadcastable == (False, False)
assert str(t1) == "TensorType(float64, (2, 3))"
assert str(t1) == "Matrix(float64, shape=(2, 3))"
t1 = TensorType("float64", shape=(1,))
assert t1.shape == (1,)
......
......@@ -282,7 +282,7 @@ def test_debugprint():
│ │ └─ B
│ ├─ TensorConstant{1.0}
│ ├─ B
│ ├─ <TensorType(float64, (?,))>
│ ├─ <Vector(float64, shape=(?,))>
│ └─ TensorConstant{0.0}
├─ D
└─ A
......@@ -316,9 +316,9 @@ def test_debugprint_id_type():
exp_res = f"""Add [id {e_at.auto_name}]
├─ dot [id {d_at.auto_name}]
│ ├─ <TensorType(float64, (?, ?))> [id {b_at.auto_name}]
│ └─ <TensorType(float64, (?,))> [id {a_at.auto_name}]
└─ <TensorType(float64, (?,))> [id {a_at.auto_name}]
│ ├─ <Matrix(float64, shape=(?, ?))> [id {b_at.auto_name}]
│ └─ <Vector(float64, shape=(?,))> [id {a_at.auto_name}]
└─ <Vector(float64, shape=(?,))> [id {a_at.auto_name}]
"""
assert [l.strip() for l in s.split("\n")] == [
......@@ -329,7 +329,7 @@ def test_debugprint_id_type():
def test_pprint():
x = dvector()
y = x[1]
assert pp(y) == "<TensorType(float64, (?,))>[1]"
assert pp(y) == "<Vector(float64, shape=(?,))>[1]"
def test_debugprint_inner_graph():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论