提交 b83398e6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add debugprint print_destroy_map and print_view_map

上级 b4c1bbd5
...@@ -116,6 +116,8 @@ def debugprint( ...@@ -116,6 +116,8 @@ def debugprint(
print_storage: bool = False, print_storage: bool = False,
used_ids: Optional[Dict[Variable, str]] = None, used_ids: Optional[Dict[Variable, str]] = None,
print_op_info: bool = False, print_op_info: bool = False,
print_destroy_map: bool = False,
print_view_map: bool = False,
) -> Union[str, IOBase]: ) -> Union[str, IOBase]:
r"""Print a computation graph as text to stdout or a file. r"""Print a computation graph as text to stdout or a file.
...@@ -169,6 +171,10 @@ def debugprint( ...@@ -169,6 +171,10 @@ def debugprint(
print_op_info print_op_info
Print extra information provided by the relevant `Op`\s. For example, Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs. print the tap information for `Scan` inputs and outputs.
print_destroy_map
Whether to print the `destroy_map`\s of printed objects
print_view_map
Whether to print the `view_map`\s of printed objects
Returns Returns
------- -------
...@@ -286,6 +292,8 @@ N.B.: ...@@ -286,6 +292,8 @@ N.B.:
op_information=op_information, op_information=op_information,
parent_node=r.owner, parent_node=r.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
) )
if len(inner_graph_ops) > 0: if len(inner_graph_ops) > 0:
...@@ -340,6 +348,8 @@ N.B.: ...@@ -340,6 +348,8 @@ N.B.:
op_information=op_information, op_information=op_information,
parent_node=s.owner, parent_node=s.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
) )
for idx, i in enumerate(inner_outputs): for idx, i in enumerate(inner_outputs):
...@@ -363,6 +373,8 @@ N.B.: ...@@ -363,6 +373,8 @@ N.B.:
op_information=op_information, op_information=op_information,
parent_node=s.owner, parent_node=s.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
) )
if file is _file: if file is _file:
...@@ -507,22 +519,16 @@ def _debugprint( ...@@ -507,22 +519,16 @@ def _debugprint(
if r_name: if r_name:
r_name = f" '{r_name}'" r_name = f" '{r_name}'"
if print_destroy_map: if print_destroy_map and r.owner.op.destroy_map:
destroy_map_str = str(r.owner.op.destroy_map) destroy_map_str = f" d={r.owner.op.destroy_map}"
else: else:
destroy_map_str = "" destroy_map_str = ""
if print_view_map: if print_view_map and r.owner.op.view_map:
view_map_str = str(r.owner.op.view_map) view_map_str = f" v={r.owner.op.view_map}"
else: else:
view_map_str = "" view_map_str = ""
if destroy_map_str and destroy_map_str != "{}":
destroy_map_str = f" d={destroy_map_str} "
if view_map_str and view_map_str != "{}":
view_map_str = f" v={view_map_str} "
if order: if order:
o = f" {order.index(r.owner)}" o = f" {order.index(r.owner)}"
else: else:
...@@ -607,6 +613,8 @@ def _debugprint( ...@@ -607,6 +613,8 @@ def _debugprint(
op_information=op_information, op_information=op_information,
parent_node=a, parent_node=a,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
) )
else: else:
......
...@@ -153,10 +153,6 @@ def test_debugprint(): ...@@ -153,10 +153,6 @@ def test_debugprint():
+ "\n" + "\n"
) )
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference assert s == reference
# test ids=CHAR # test ids=CHAR
...@@ -179,10 +175,6 @@ def test_debugprint(): ...@@ -179,10 +175,6 @@ def test_debugprint():
+ "\n" + "\n"
) )
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference assert s == reference
# test ids=CHAR, stop_on_name=True # test ids=CHAR, stop_on_name=True
...@@ -203,10 +195,6 @@ def test_debugprint(): ...@@ -203,10 +195,6 @@ def test_debugprint():
+ "\n" + "\n"
) )
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference assert s == reference
# test ids= # test ids=
...@@ -228,9 +216,6 @@ def test_debugprint(): ...@@ -228,9 +216,6 @@ def test_debugprint():
) )
+ "\n" + "\n"
) )
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference assert s == reference
...@@ -250,12 +235,40 @@ def test_debugprint(): ...@@ -250,12 +235,40 @@ def test_debugprint():
) )
+ "\n" + "\n"
) )
if s != reference:
print("--" + s + "--")
print("--" + reference + "--")
assert s == reference assert s == reference
A = dmatrix(name="A")
B = dmatrix(name="B")
D = dmatrix(name="D")
J = dvector()
s = StringIO()
debugprint(
aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"),
file=s,
ids="",
print_destroy_map=True,
print_view_map=True,
)
s = s.getvalue()
exp_res = r"""Elemwise{Composite{(i0 + (i1 - i2))}} 4
|A
|InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1
| | |Shape_i{0} 0
| | |B
| |TensorConstant{1.0}
| |B
| |<TensorType(float64, (None,))>
| |TensorConstant{0.0}
|D
"""
assert [l.strip() for l in s.split("\n")] == [
l.strip() for l in exp_res.split("\n")
]
def test_debugprint_ids(): def test_debugprint_ids():
a_at = dvector() a_at = dvector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论