Unverified 提交 20581c3f authored 作者: Oscar Gustafsson's avatar Oscar Gustafsson 提交者: GitHub

Make TensorConstants print their names when specified

In addition, the value is printed (and sometimes the data type).
上级 cb6fd028
...@@ -986,13 +986,17 @@ class TensorConstant(TensorVariable, Constant): ...@@ -986,13 +986,17 @@ class TensorConstant(TensorVariable, Constant):
def __str__(self): def __str__(self):
unique_val = get_unique_value(self) unique_val = get_unique_value(self)
if unique_val is not None: if unique_val is not None:
name = f"{self.data.shape} of {unique_val}" val = f"{self.data.shape} of {unique_val}"
else: else:
name = f"{self.data}" val = f"{self.data}"
if len(name) > 20: if len(val) > 20:
name = name[:10] + ".." + name[-10:] val = val[:10] + ".." + val[-10:]
return "TensorConstant{%s}" % name if self.name is not None:
name = self.name
else:
name = "TensorConstant"
return "%s{%s}" % (name, val)
def signature(self): def signature(self):
return TensorConstantSignature((self.type, self.data)) return TensorConstantSignature((self.type, self.data))
......
...@@ -196,6 +196,13 @@ def test__getitem__AdvancedSubtensor(): ...@@ -196,6 +196,13 @@ def test__getitem__AdvancedSubtensor():
assert op_types[-1] == AdvancedSubtensor assert op_types[-1] == AdvancedSubtensor
def test_print_constant():
c = aesara.tensor.constant(1, name="const")
assert str(c) == "const{1}"
d = aesara.tensor.constant(1)
assert str(d) == "TensorConstant{1}"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices, new_order", "x, indices, new_order",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论