Unverified 提交 20581c3f authored 作者: Oscar Gustafsson's avatar Oscar Gustafsson 提交者: GitHub

Make TensorConstants print their names when specified

In addition, the value is printed (and sometimes the data type).
上级 cb6fd028
......@@ -986,13 +986,17 @@ class TensorConstant(TensorVariable, Constant):
def __str__(self):
unique_val = get_unique_value(self)
if unique_val is not None:
name = f"{self.data.shape} of {unique_val}"
val = f"{self.data.shape} of {unique_val}"
else:
name = f"{self.data}"
if len(name) > 20:
name = name[:10] + ".." + name[-10:]
val = f"{self.data}"
if len(val) > 20:
val = val[:10] + ".." + val[-10:]
return "TensorConstant{%s}" % name
if self.name is not None:
name = self.name
else:
name = "TensorConstant"
return "%s{%s}" % (name, val)
def signature(self):
return TensorConstantSignature((self.type, self.data))
......
......@@ -196,6 +196,13 @@ def test__getitem__AdvancedSubtensor():
assert op_types[-1] == AdvancedSubtensor
def test_print_constant():
c = aesara.tensor.constant(1, name="const")
assert str(c) == "const{1}"
d = aesara.tensor.constant(1)
assert str(d) == "TensorConstant{1}"
@pytest.mark.parametrize(
"x, indices, new_order",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论