提交 69c10443 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix _debugprint's handling of empty profile data

上级 eddc85fc
...@@ -630,11 +630,7 @@ def _debugprint( ...@@ -630,11 +630,7 @@ def _debugprint(
if node_info and var in node_info: if node_info and var in node_info:
var_output = f"{var_output} ({node_info[var]})" var_output = f"{var_output} ({node_info[var]})"
if profile is None: if profile and profile.apply_time and node in profile.apply_time:
print(var_output, file=file)
elif profile.apply_time and node not in profile.apply_time:
print(var_output, file=file)
elif profile.apply_time and node in profile.apply_time:
op_time = profile.apply_time[node] op_time = profile.apply_time[node]
op_time_percent = (op_time / profile.fct_call_time) * 100 op_time_percent = (op_time / profile.fct_call_time) * 100
tot_time_dict = profile.compute_total_times() tot_time_dict = profile.compute_total_times()
...@@ -652,6 +648,8 @@ def _debugprint( ...@@ -652,6 +648,8 @@ def _debugprint(
), ),
file=file, file=file,
) )
else:
print(var_output, file=file)
if not already_done and ( if not already_done and (
not stop_on_name or not (hasattr(var, "name") and var.name is not None) not stop_on_name or not (hasattr(var, "name") and var.name is not None)
......
...@@ -3,7 +3,9 @@ Tests of printing functionality ...@@ -3,7 +3,9 @@ Tests of printing functionality
""" """
import logging import logging
from io import StringIO from io import StringIO
from textwrap import dedent
import numpy as np
import pytest import pytest
import aesara import aesara
...@@ -121,12 +123,12 @@ def test_debugprint(): ...@@ -121,12 +123,12 @@ def test_debugprint():
with pytest.raises(TypeError): with pytest.raises(TypeError):
debugprint("blah") debugprint("blah")
A = matrix(name="A") A = dmatrix(name="A")
B = matrix(name="B") B = dmatrix(name="B")
C = A + B C = A + B
C.name = "C" C.name = "C"
D = matrix(name="D") D = dmatrix(name="D")
E = matrix(name="E") E = dmatrix(name="E")
F = D + E F = D + E
G = C + F G = C + F
...@@ -140,21 +142,17 @@ def test_debugprint(): ...@@ -140,21 +142,17 @@ def test_debugprint():
s = StringIO() s = StringIO()
debugprint(G, file=s, id_type="int") debugprint(G, file=s, id_type="int")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! reference = dedent(
reference = ( r"""
"\n".join( Elemwise{add,no_inplace} [id 0]
[ |Elemwise{add,no_inplace} [id 1] 'C'
"Elemwise{add,no_inplace} [id 0]", | |A [id 2]
" |Elemwise{add,no_inplace} [id 1] 'C'", | |B [id 3]
" | |A [id 2]", |Elemwise{add,no_inplace} [id 4]
" | |B [id 3]", |D [id 5]
" |Elemwise{add,no_inplace} [id 4]", |E [id 6]
" |D [id 5]", """
" |E [id 6]", ).lstrip()
]
)
+ "\n"
)
assert s == reference assert s == reference
...@@ -162,20 +160,17 @@ def test_debugprint(): ...@@ -162,20 +160,17 @@ def test_debugprint():
debugprint(G, file=s, id_type="CHAR") debugprint(G, file=s, id_type="CHAR")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = dedent(
"\n".join( r"""
[ Elemwise{add,no_inplace} [id A]
"Elemwise{add,no_inplace} [id A]", |Elemwise{add,no_inplace} [id B] 'C'
" |Elemwise{add,no_inplace} [id B] 'C'", | |A [id C]
" | |A [id C]", | |B [id D]
" | |B [id D]", |Elemwise{add,no_inplace} [id E]
" |Elemwise{add,no_inplace} [id E]", |D [id F]
" |D [id F]", |E [id G]
" |E [id G]", """
] ).lstrip()
)
+ "\n"
)
assert s == reference assert s == reference
...@@ -183,61 +178,86 @@ def test_debugprint(): ...@@ -183,61 +178,86 @@ def test_debugprint():
debugprint(G, file=s, id_type="CHAR", stop_on_name=True) debugprint(G, file=s, id_type="CHAR", stop_on_name=True)
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = dedent(
"\n".join( r"""
[ Elemwise{add,no_inplace} [id A]
"Elemwise{add,no_inplace} [id A]", |Elemwise{add,no_inplace} [id B] 'C'
" |Elemwise{add,no_inplace} [id B] 'C'", |Elemwise{add,no_inplace} [id C]
" |Elemwise{add,no_inplace} [id C]", |D [id D]
" |D [id D]", |E [id E]
" |E [id E]", """
] ).lstrip()
)
+ "\n"
)
assert s == reference assert s == reference
s = StringIO() s = StringIO()
debugprint(G, file=s, id_type="") debugprint(G, file=s, id_type="")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! reference = dedent(
reference = ( r"""
"\n".join( Elemwise{add,no_inplace}
[ |Elemwise{add,no_inplace} 'C'
"Elemwise{add,no_inplace}", | |A
" |Elemwise{add,no_inplace} 'C'", | |B
" | |A", |Elemwise{add,no_inplace}
" | |B", |D
" |Elemwise{add,no_inplace}", |E
" |D", """
" |E", ).lstrip()
]
)
+ "\n"
)
assert s == reference assert s == reference
# test print_storage=True
s = StringIO() s = StringIO()
debugprint(g, file=s, id_type="", print_storage=True) debugprint(g, file=s, id_type="", print_storage=True)
s = s.getvalue() s = s.getvalue()
reference = ( reference = dedent(
"\n".join( r"""
[ Elemwise{add,no_inplace} 0 [None]
"Elemwise{add,no_inplace} 0 [None]", |A [None]
" |A [None]", |B [None]
" |B [None]", |D [None]
" |D [None]", |E [None]
" |E [None]", """
] ).lstrip()
)
+ "\n"
)
assert s == reference assert s == reference
# Test the `profile` handling when profile data is missing
g = aesara.function([A, B, D, E], G, mode=mode, profile=True)
s = StringIO()
debugprint(g, file=s, id_type="", print_storage=True)
s = s.getvalue()
reference = dedent(
r"""
Elemwise{add,no_inplace} 0 [None]
|A [None]
|B [None]
|D [None]
|E [None]
"""
).lstrip()
assert s == reference
# Add profile data
g(np.c_[[1.0]], np.c_[[1.0]], np.c_[[1.0]], np.c_[[1.0]])
s = StringIO()
debugprint(g, file=s, id_type="", print_storage=True)
s = s.getvalue()
reference = dedent(
r"""
Elemwise{add,no_inplace} 0 [None]
|A [None]
|B [None]
|D [None]
|E [None]
"""
).lstrip()
assert reference in s
A = dmatrix(name="A") A = dmatrix(name="A")
B = dmatrix(name="B") B = dmatrix(name="B")
D = dmatrix(name="D") D = dmatrix(name="D")
...@@ -251,7 +271,9 @@ def test_debugprint(): ...@@ -251,7 +271,9 @@ def test_debugprint():
print_view_map=True, print_view_map=True,
) )
s = s.getvalue() s = s.getvalue()
exp_res = r"""Elemwise{Composite{(i0 + (i1 - i2))}} 4 exp_res = dedent(
r"""
Elemwise{Composite{(i0 + (i1 - i2))}} 4
|A |A
|InplaceDimShuffle{x,0} v={0: [0]} 3 |InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2 | |CGemv{inplace} d={0: [0]} 2
...@@ -264,6 +286,7 @@ def test_debugprint(): ...@@ -264,6 +286,7 @@ def test_debugprint():
| |TensorConstant{0.0} | |TensorConstant{0.0}
|D |D
""" """
).lstrip()
assert [l.strip() for l in s.split("\n")] == [ assert [l.strip() for l in s.split("\n")] == [
l.strip() for l in exp_res.split("\n") l.strip() for l in exp_res.split("\n")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论