提交 1944353c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify string expression for Constant Variables

上级 66934167
......@@ -763,13 +763,20 @@ class Constant(AtomicVariable[_TypeType]):
return (self.type, self.data)
def __str__(self):
if self.name is not None:
return self.name
else:
name = str(self.data)
if len(name) > 20:
name = name[:10] + "..." + name[-10:]
return f"{type(self).__name__}{{{name}}}"
data_str = str(self.data)
if len(data_str) > 20:
data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
if self.name is None:
return data_str
return f"{self.name}{{{data_str}}}"
def __repr__(self):
data_str = repr(self.data)
if len(data_str) > 20:
data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
return f"{type(self).__name__}({repr(self.type)}, data={data_str})"
def clone(self, **kwargs):
return self
......
......@@ -1023,21 +1023,6 @@ class TensorConstant(TensorVariable, Constant[_TensorTypeType]):
Constant.__init__(self, new_type, data, name)
def __str__(self):
unique_val = get_unique_value(self)
if unique_val is not None:
val = f"{self.data.shape} of {unique_val}"
else:
val = f"{self.data}"
if len(val) > 20:
val = val[:10].strip() + " ... " + val[-10:].strip()
if self.name is not None:
name = self.name
else:
name = "TensorConstant"
return f"{name}{{{val}}}"
def signature(self):
return TensorConstantSignature((self.type, self.data))
......
......@@ -166,7 +166,7 @@ class TestPatternNodeRewriter:
e = op1(op1(x, y), y)
g = FunctionGraph([y], [e])
OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))"
assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))"
def test_constraints(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
......
......@@ -213,7 +213,7 @@ def test_print_constant():
c = pytensor.tensor.constant(1, name="const")
assert str(c) == "const{1}"
d = pytensor.tensor.constant(1)
assert str(d) == "TensorConstant{1}"
assert str(d) == "1"
@pytest.mark.parametrize(
......
......@@ -280,10 +280,10 @@ def test_debugprint():
│ ├─ AllocEmpty{dtype='float64'} 1
│ │ └─ Shape_i{0} 0
│ │ └─ B
│ ├─ TensorConstant{1.0}
│ ├─ 1.0
│ ├─ B
│ ├─ <Vector(float64, shape=(?,))>
│ └─ TensorConstant{0.0}
│ └─ 0.0
├─ D
└─ A
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论