Unverified 提交 1390cc39 authored 作者: Anirudh's avatar Anirudh 提交者: GitHub

Use time.perf_counter for aesara.function timing results

上级 471657a5
...@@ -824,7 +824,7 @@ class Function: ...@@ -824,7 +824,7 @@ class Function:
self[i] = value self[i] = value
profile = self.profile profile = self.profile
t0 = time.time() t0 = time.perf_counter()
output_subset = kwargs.pop("output_subset", None) output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None: if output_subset is not None and self.output_keys is not None:
...@@ -965,7 +965,7 @@ class Function: ...@@ -965,7 +965,7 @@ class Function:
) )
# Do the actual work # Do the actual work
t0_fn = time.time() t0_fn = time.perf_counter()
try: try:
outputs = ( outputs = (
self.vm() self.vm()
...@@ -991,7 +991,7 @@ class Function: ...@@ -991,7 +991,7 @@ class Function:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
dt_fn = time.time() - t0_fn dt_fn = time.perf_counter() - t0_fn
self.maker.mode.fn_time += dt_fn self.maker.mode.fn_time += dt_fn
if profile: if profile:
profile.vm_call_time += dt_fn profile.vm_call_time += dt_fn
...@@ -1039,7 +1039,7 @@ class Function: ...@@ -1039,7 +1039,7 @@ class Function:
# grep for 'PROFILE_CODE' # grep for 'PROFILE_CODE'
# #
dt_call = time.time() - t0 dt_call = time.perf_counter() - t0
aesara.compile.profiling.total_fct_exec_time += dt_call aesara.compile.profiling.total_fct_exec_time += dt_call
self.maker.mode.call_time += dt_call self.maker.mode.call_time += dt_call
if profile: if profile:
...@@ -1395,7 +1395,7 @@ class FunctionMaker: ...@@ -1395,7 +1395,7 @@ class FunctionMaker:
): ):
try: try:
start_rewriter = time.time() start_rewriter = time.perf_counter()
rewriter_profile = None rewriter_profile = None
rewrite_time = None rewrite_time = None
...@@ -1406,7 +1406,7 @@ class FunctionMaker: ...@@ -1406,7 +1406,7 @@ class FunctionMaker:
): ):
rewriter_profile = rewriter(fgraph) rewriter_profile = rewriter(fgraph)
end_rewriter = time.time() end_rewriter = time.perf_counter()
rewrite_time = end_rewriter - start_rewriter rewrite_time = end_rewriter - start_rewriter
_logger.debug(f"Rewriting took {rewrite_time:f} seconds") _logger.debug(f"Rewriting took {rewrite_time:f} seconds")
...@@ -1416,7 +1416,7 @@ class FunctionMaker: ...@@ -1416,7 +1416,7 @@ class FunctionMaker:
# If the rewriter got interrupted # If the rewriter got interrupted
if rewrite_time is None: if rewrite_time is None:
end_rewriter = time.time() end_rewriter = time.perf_counter()
rewrite_time = end_rewriter - start_rewriter rewrite_time = end_rewriter - start_rewriter
aesara.compile.profiling.total_graph_rewrite_time += rewrite_time aesara.compile.profiling.total_graph_rewrite_time += rewrite_time
...@@ -1645,7 +1645,7 @@ class FunctionMaker: ...@@ -1645,7 +1645,7 @@ class FunctionMaker:
defaults.append((required, refeed, storage)) defaults.append((required, refeed, storage))
# Get a function instance # Get a function instance
start_linker = time.time() start_linker = time.perf_counter()
start_import_time = aesara.link.c.cmodule.import_time start_import_time = aesara.link.c.cmodule.import_time
with config.change_flags(traceback__limit=config.traceback__compile_limit): with config.change_flags(traceback__limit=config.traceback__compile_limit):
...@@ -1653,7 +1653,7 @@ class FunctionMaker: ...@@ -1653,7 +1653,7 @@ class FunctionMaker:
input_storage=input_storage_lists, storage_map=storage_map input_storage=input_storage_lists, storage_map=storage_map
) )
end_linker = time.time() end_linker = time.perf_counter()
linker_time = end_linker - start_linker linker_time = end_linker - start_linker
aesara.compile.profiling.total_time_linker += linker_time aesara.compile.profiling.total_time_linker += linker_time
...@@ -1725,7 +1725,7 @@ def orig_function( ...@@ -1725,7 +1725,7 @@ def orig_function(
""" """
t1 = time.time() t1 = time.perf_counter()
mode = aesara.compile.mode.get_mode(mode) mode = aesara.compile.mode.get_mode(mode)
inputs = list(map(convert_function_input, inputs)) inputs = list(map(convert_function_input, inputs))
...@@ -1758,7 +1758,7 @@ def orig_function( ...@@ -1758,7 +1758,7 @@ def orig_function(
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
fn = m.create(defaults) fn = m.create(defaults)
finally: finally:
t2 = time.time() t2 = time.perf_counter()
if fn and profile: if fn and profile:
profile.compile_time += t2 - t1 profile.compile_time += t2 - t1
# TODO: append # TODO: append
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论