提交 b83398e6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add debugprint print_destroy_map and print_view_map

上级 b4c1bbd5
......@@ -116,6 +116,8 @@ def debugprint(
print_storage: bool = False,
used_ids: Optional[Dict[Variable, str]] = None,
print_op_info: bool = False,
print_destroy_map: bool = False,
print_view_map: bool = False,
) -> Union[str, IOBase]:
r"""Print a computation graph as text to stdout or a file.
......@@ -169,6 +171,10 @@ def debugprint(
print_op_info
Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs.
print_destroy_map
Whether to print the `destroy_map`\s of printed objects
print_view_map
Whether to print the `view_map`\s of printed objects
Returns
-------
......@@ -286,6 +292,8 @@ N.B.:
op_information=op_information,
parent_node=r.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
if len(inner_graph_ops) > 0:
......@@ -340,6 +348,8 @@ N.B.:
op_information=op_information,
parent_node=s.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
for idx, i in enumerate(inner_outputs):
......@@ -363,6 +373,8 @@ N.B.:
op_information=op_information,
parent_node=s.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
if file is _file:
......@@ -507,22 +519,16 @@ def _debugprint(
if r_name:
r_name = f" '{r_name}'"
if print_destroy_map:
destroy_map_str = str(r.owner.op.destroy_map)
if print_destroy_map and r.owner.op.destroy_map:
destroy_map_str = f" d={r.owner.op.destroy_map}"
else:
destroy_map_str = ""
if print_view_map:
view_map_str = str(r.owner.op.view_map)
if print_view_map and r.owner.op.view_map:
view_map_str = f" v={r.owner.op.view_map}"
else:
view_map_str = ""
if destroy_map_str and destroy_map_str != "{}":
destroy_map_str = f" d={destroy_map_str} "
if view_map_str and view_map_str != "{}":
view_map_str = f" v={view_map_str} "
if order:
o = f" {order.index(r.owner)}"
else:
......@@ -607,6 +613,8 @@ def _debugprint(
op_information=op_information,
parent_node=a,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
else:
......
......@@ -153,10 +153,6 @@ def test_debugprint():
+ "\n"
)
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference
# test ids=CHAR
......@@ -179,10 +175,6 @@ def test_debugprint():
+ "\n"
)
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference
# test ids=CHAR, stop_on_name=True
......@@ -203,10 +195,6 @@ def test_debugprint():
+ "\n"
)
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference
# test ids=
......@@ -228,9 +216,6 @@ def test_debugprint():
)
+ "\n"
)
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference
......@@ -250,12 +235,40 @@ def test_debugprint():
)
+ "\n"
)
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference
A = dmatrix(name="A")
B = dmatrix(name="B")
D = dmatrix(name="D")
J = dvector()
s = StringIO()
debugprint(
aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"),
file=s,
ids="",
print_destroy_map=True,
print_view_map=True,
)
s = s.getvalue()
exp_res = r"""Elemwise{Composite{(i0 + (i1 - i2))}} 4
|A
|InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1
| | |Shape_i{0} 0
| | |B
| |TensorConstant{1.0}
| |B
| |<TensorType(float64, (None,))>
| |TensorConstant{0.0}
|D
"""
assert [l.strip() for l in s.split("\n")] == [
l.strip() for l in exp_res.split("\n")
]
def test_debugprint_ids():
a_at = dvector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论