提交 b17de1db authored 作者: AuguB's avatar AuguB 提交者: Ricardo Vieira

Correct usage of file context in profile.py

上级 222cd4ad
...@@ -62,7 +62,7 @@ def _atexit_print_fn(): ...@@ -62,7 +62,7 @@ def _atexit_print_fn():
else: else:
destination_file = config.profiling__destination destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"): with extended_open(destination_file, mode="w") as f:
# Reverse sort in the order of compile+exec time # Reverse sort in the order of compile+exec time
for ps in sorted( for ps in sorted(
_atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time _atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
...@@ -73,7 +73,7 @@ def _atexit_print_fn(): ...@@ -73,7 +73,7 @@ def _atexit_print_fn():
or getattr(ps, "callcount", 0) > 1 or getattr(ps, "callcount", 0) > 1
): ):
ps.summary( ps.summary(
file=destination_file, file=f,
n_ops_to_print=config.profiling__n_ops, n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply, n_apply_to_print=config.profiling__n_apply,
) )
...@@ -131,7 +131,7 @@ def _atexit_print_fn(): ...@@ -131,7 +131,7 @@ def _atexit_print_fn():
cum.rewriter_profile = None cum.rewriter_profile = None
cum.summary( cum.summary(
file=destination_file, file=f,
n_ops_to_print=config.profiling__n_ops, n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply, n_apply_to_print=config.profiling__n_apply,
) )
...@@ -157,7 +157,7 @@ def print_global_stats(): ...@@ -157,7 +157,7 @@ def print_global_stats():
else: else:
destination_file = config.profiling__destination destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"): with extended_open(destination_file, mode="w") as f:
print("=" * 50, file=destination_file) print("=" * 50, file=destination_file)
print( print(
( (
...@@ -167,9 +167,9 @@ def print_global_stats(): ...@@ -167,9 +167,9 @@ def print_global_stats():
"Time spent compiling PyTensor functions: " "Time spent compiling PyTensor functions: "
f"rewriting = {total_graph_rewrite_time:6.3f}s, linking = {total_time_linker:6.3f}s ", f"rewriting = {total_graph_rewrite_time:6.3f}s, linking = {total_time_linker:6.3f}s ",
), ),
file=destination_file, file=f,
) )
print("=" * 50, file=destination_file) print("=" * 50, file=f)
_profiler_printers = [] _profiler_printers = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论