提交 debdd547 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a pretty printer for Constants

上级 61b1bbb9
......@@ -846,6 +846,18 @@ class LeafPrinter(Printer):
leaf_printer = LeafPrinter()
class ConstantPrinter(Printer):
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
r = str(output.data)
pstate.memo[output] = r
return r
constant_printer = ConstantPrinter()
class DefaultPrinter(Printer):
def process(self, output, pstate):
if output in pstate.memo:
......@@ -995,6 +1007,8 @@ else:
pprint = PPrinter()
pprint.assign(lambda pstate, r: True, default_printer)
pprint.assign(lambda pstate, r: isinstance(r, Constant), constant_printer)
pp = pprint
"""
......
......@@ -176,9 +176,9 @@ class TestComposite:
assert str(g) == (
"FunctionGraph(*1 -> Composite{((i0 + i1) + i2),"
" (i0 + (i1 * i2)), (i0 / i1), "
"(i0 // ScalarConstant{5}), "
"(i0 // 5), "
"(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),"
" (i0 % ScalarConstant{3})}(x, y, z), "
" (i0 % 3)}(x, y, z), "
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
)
......
......@@ -281,7 +281,7 @@ def test_debugprint_ids():
def test_pprint():
x = dvector()
y = x[1]
assert pp(y) == "<TensorType(float64, vector)>[ScalarConstant{1}]"
assert pp(y) == "<TensorType(float64, vector)>[1]"
def test_debugprint_inner_graph():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论