提交 af069dbb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move aesara.compile.profiling.ScanProfileStats to aesara.scan.utils

上级 dee152ce
......@@ -59,5 +59,5 @@ from aesara.compile.ops import (
register_view_op_c_code,
view_op,
)
from aesara.compile.profiling import ProfileStats, ScanProfileStats
from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared, shared_constructor
......@@ -26,11 +26,6 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
# Module metadata. The adjacent string literals in `__authors__` are
# concatenated by Python into a single string.
__authors__ = "James Bergstra " "PyMC Developers " "Aesara Developers "
__copyright__ = "(c) 2011, Universite de Montreal"
__docformat__ = "restructuredtext en"
# Module-level logger registered under the name "aesara.compile.profiling".
logger = logging.getLogger("aesara.compile.profiling")
# Wall-clock timestamp taken when this module's top level is executed.
aesara_imported_time = time.time()
......@@ -71,7 +66,8 @@ def _atexit_print_fn():
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)
if not isinstance(ps, ScanProfileStats):
if ps.show_sum:
to_sum.append(ps)
else:
# TODO print the name if there is one!
......@@ -79,10 +75,7 @@ def _atexit_print_fn():
if len(to_sum) > 1:
# Make a global profile
cum = copy.copy(to_sum[0])
msg = (
f"Sum of all({len(to_sum)}) printed profiles at exit excluding Scan op"
" profile."
)
msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
cum.message = msg
for ps in to_sum[1:]:
for attr in [
......@@ -209,6 +202,7 @@ class ProfileStats:
# documented and initialized together.
# dictionary variables are initialized with None.
#
show_sum: bool = True
compile_time = 0.0
# Total time spent in body of orig_function,
......@@ -1729,61 +1723,3 @@ class ProfileStats:
]
for f in _profiler_printers:
f(*params, file=file)
class ScanProfileStats(ProfileStats):
    """Profiling statistics for a single Scan op.

    Extends `ProfileStats` with counters describing how often the op was
    called, how many steps it ran in total and how much time was spent in
    it, and prints a Scan-specific summary instead of the generic one.
    """

    # Class-level defaults for the Scan-specific counters read by
    # `summary_function`.
    callcount = 0.0
    nbsteps = 0.0
    call_time = 0.0

    def __init__(self, atexit_print=True, name=None, **kwargs):
        """Create the profile.

        Parameters
        ----------
        atexit_print
            Forwarded positionally to `ProfileStats.__init__`.
        name
            Optional label shown in the summary header.
        **kwargs
            Forwarded unchanged to `ProfileStats.__init__`.
        """
        super().__init__(atexit_print, **kwargs)
        self.name = name

    def summary_globals(self, file):
        """Print nothing: no extra global summary is wanted for Scan ops."""

    def summary_function(self, file):
        """Write the Scan op timing summary to `file`.

        Nothing is printed when the op was never called.

        NOTE(review): `self.message` and `self.vm_call_time` are not defined
        in this class; presumably they are provided by `ProfileStats`.
        """
        # RP: every time we compile a function, a ProfileStats is created for
        # that function. This means that every time an optimization replaces
        # some scan op, an orphaned ProfileStats is left behind. Even without
        # any optimization, scan compiles a dummy function that produces a
        # ProfileStats for a function that will never be called. Printing
        # several empty function profiles is just extremely confusing.
        if self.callcount == 0:
            return

        print("", file=file)
        if self.name is not None:
            print("Scan Op profiling (", self.name, ")", file=file)
        else:
            print("Scan Op profiling", file=file)
        print("==================", file=file)
        print(f" Message: {self.message}", file=file)

        # The counters default to floats, so render them as integers; the
        # total time uses the same `e` presentation as the timings below
        # (a bare `3` format spec is only a minimum field width).
        print(
            (
                f" Time in {int(self.callcount)} calls of the op (for a total of "
                f"{int(self.nbsteps)} steps) {self.call_time:e}s"
            ),
            file=file,
        )
        print("", file=file)

        # Share of the total time spent in the VM vs. in overhead. When no
        # time was recorded, report 0% VM and 100% overhead.
        if self.call_time > 0:
            vm_percent = self.vm_call_time * 100 / self.call_time
            overhead_percent = 100.0 - vm_percent
        else:
            vm_percent = 0
            overhead_percent = 100
        overhead_time = self.call_time - self.vm_call_time

        print(
            f" Total time spent in calling the VM {self.vm_call_time:e}s ({vm_percent:.3f}%)",
            file=file,
        )
        print(
            f" Total overhead (computing slices..) {overhead_time:e}s ({overhead_percent:.3f}%)",
            file=file,
        )
        print("", file=file)
......@@ -59,7 +59,7 @@ from aesara.compile.builders import infer_shape
from aesara.compile.function import function
from aesara.compile.io import In, Out
from aesara.compile.mode import AddFeatureOptimizer, Mode, get_default_mode, get_mode
from aesara.compile.profiling import ScanProfileStats, register_profiler_printer
from aesara.compile.profiling import register_profiler_printer
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
from aesara.graph.basic import (
......@@ -77,7 +77,7 @@ from aesara.graph.op import HasInnerGraph, Op
from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op
from aesara.scan.utils import Validator, forced_replace, safe_new
from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import minimum
from aesara.tensor.shape import Shape_i
......
......@@ -11,6 +11,7 @@ import numpy as np
from aesara import scalar as aes
from aesara import tensor as at
from aesara.compile.profiling import ProfileStats
from aesara.configdefaults import config
from aesara.graph.basic import (
Constant,
......@@ -123,6 +124,65 @@ class until:
assert self.condition.ndim == 0
class ScanProfileStats(ProfileStats):
    """Profiling statistics for a single Scan op.

    Extends `ProfileStats` with counters describing how often the op was
    called, how many steps it ran in total and how much time was spent in
    it, and prints a Scan-specific summary instead of the generic one.
    """

    # Overrides the `ProfileStats` default of `True`, so that code checking
    # `show_sum` leaves Scan profiles out of any summed profile.
    show_sum = False
    # Class-level defaults for the Scan-specific counters read by
    # `summary_function`.
    callcount = 0.0
    nbsteps = 0.0
    call_time = 0.0

    def __init__(self, atexit_print=True, name=None, **kwargs):
        """Create the profile.

        Parameters
        ----------
        atexit_print
            Forwarded positionally to `ProfileStats.__init__`.
        name
            Optional label shown in the summary header.
        **kwargs
            Forwarded unchanged to `ProfileStats.__init__`.
        """
        super().__init__(atexit_print, **kwargs)
        self.name = name

    def summary_globals(self, file):
        """Print nothing: no extra global summary is wanted for Scan ops."""

    def summary_function(self, file):
        """Write the Scan op timing summary to `file`.

        Nothing is printed when the op was never called.

        NOTE(review): `self.message` and `self.vm_call_time` are not defined
        in this class; presumably they are provided by `ProfileStats`.
        """
        # RP: every time we compile a function, a ProfileStats is created for
        # that function. This means that every time an optimization replaces
        # some scan op, an orphaned ProfileStats is left behind. Even without
        # any optimization, scan compiles a dummy function that produces a
        # ProfileStats for a function that will never be called. Printing
        # several empty function profiles is just extremely confusing.
        if self.callcount == 0:
            return

        print("", file=file)
        if self.name is not None:
            print("Scan Op profiling (", self.name, ")", file=file)
        else:
            print("Scan Op profiling", file=file)
        print("==================", file=file)
        print(f" Message: {self.message}", file=file)

        # The counters default to floats, so render them as integers; the
        # total time uses the same `e` presentation as the timings below
        # (a bare `3` format spec is only a minimum field width).
        print(
            (
                f" Time in {int(self.callcount)} calls of the op (for a total of "
                f"{int(self.nbsteps)} steps) {self.call_time:e}s"
            ),
            file=file,
        )
        print("", file=file)

        # Share of the total time spent in the VM vs. in overhead. When no
        # time was recorded, report 0% VM and 100% overhead.
        if self.call_time > 0:
            vm_percent = self.vm_call_time * 100 / self.call_time
            overhead_percent = 100.0 - vm_percent
        else:
            vm_percent = 0
            overhead_percent = 100
        overhead_time = self.call_time - self.vm_call_time

        print(
            f" Total time spent in calling the VM {self.vm_call_time:e}s ({vm_percent:.3f}%)",
            file=file,
        )
        print(
            f" Total overhead (computing slices..) {overhead_time:e}s ({overhead_percent:.3f}%)",
            file=file,
        )
        print("", file=file)
def traverse(out, x, x_copy, d, visited=None):
"""
Function used by scan to parse the tree and figure out which nodes
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论