提交 1944353c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify string expression for Constant Variables

上级 66934167
...@@ -763,13 +763,20 @@ class Constant(AtomicVariable[_TypeType]): ...@@ -763,13 +763,20 @@ class Constant(AtomicVariable[_TypeType]):
return (self.type, self.data) return (self.type, self.data)
def __str__(self): def __str__(self):
if self.name is not None: data_str = str(self.data)
return self.name if len(data_str) > 20:
else: data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
name = str(self.data)
if len(name) > 20: if self.name is None:
name = name[:10] + "..." + name[-10:] return data_str
return f"{type(self).__name__}{{{name}}}"
return f"{self.name}{{{data_str}}}"
def __repr__(self):
data_str = repr(self.data)
if len(data_str) > 20:
data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
return f"{type(self).__name__}({repr(self.type)}, data={data_str})"
def clone(self, **kwargs): def clone(self, **kwargs):
return self return self
......
...@@ -1023,21 +1023,6 @@ class TensorConstant(TensorVariable, Constant[_TensorTypeType]): ...@@ -1023,21 +1023,6 @@ class TensorConstant(TensorVariable, Constant[_TensorTypeType]):
Constant.__init__(self, new_type, data, name) Constant.__init__(self, new_type, data, name)
def __str__(self):
unique_val = get_unique_value(self)
if unique_val is not None:
val = f"{self.data.shape} of {unique_val}"
else:
val = f"{self.data}"
if len(val) > 20:
val = val[:10].strip() + " ... " + val[-10:].strip()
if self.name is not None:
name = self.name
else:
name = "TensorConstant"
return f"{name}{{{val}}}"
def signature(self): def signature(self):
return TensorConstantSignature((self.type, self.data)) return TensorConstantSignature((self.type, self.data))
......
...@@ -166,7 +166,7 @@ class TestPatternNodeRewriter: ...@@ -166,7 +166,7 @@ class TestPatternNodeRewriter:
e = op1(op1(x, y), y) e = op1(op1(x, y), y)
g = FunctionGraph([y], [e]) g = FunctionGraph([y], [e])
OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))" assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))"
def test_constraints(self): def test_constraints(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
......
...@@ -213,7 +213,7 @@ def test_print_constant(): ...@@ -213,7 +213,7 @@ def test_print_constant():
c = pytensor.tensor.constant(1, name="const") c = pytensor.tensor.constant(1, name="const")
assert str(c) == "const{1}" assert str(c) == "const{1}"
d = pytensor.tensor.constant(1) d = pytensor.tensor.constant(1)
assert str(d) == "TensorConstant{1}" assert str(d) == "1"
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -280,10 +280,10 @@ def test_debugprint(): ...@@ -280,10 +280,10 @@ def test_debugprint():
│ ├─ AllocEmpty{dtype='float64'} 1 │ ├─ AllocEmpty{dtype='float64'} 1
│ │ └─ Shape_i{0} 0 │ │ └─ Shape_i{0} 0
│ │ └─ B │ │ └─ B
│ ├─ TensorConstant{1.0} │ ├─ 1.0
│ ├─ B │ ├─ B
│ ├─ <Vector(float64, shape=(?,))> │ ├─ <Vector(float64, shape=(?,))>
│ └─ TensorConstant{0.0} │ └─ 0.0
├─ D ├─ D
└─ A └─ A
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论