提交 af069dbb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move aesara.compile.profiling.ScanProfileStats to aesara.scan.utils

上级 dee152ce
...@@ -59,5 +59,5 @@ from aesara.compile.ops import ( ...@@ -59,5 +59,5 @@ from aesara.compile.ops import (
register_view_op_c_code, register_view_op_c_code,
view_op, view_op,
) )
from aesara.compile.profiling import ProfileStats, ScanProfileStats from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared, shared_constructor from aesara.compile.sharedvalue import SharedVariable, shared, shared_constructor
...@@ -26,11 +26,6 @@ from aesara.configdefaults import config ...@@ -26,11 +26,6 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
__authors__ = "James Bergstra " "PyMC Developers " "Aesara Developers "
__copyright__ = "(c) 2011, Universite de Montreal"
__docformat__ = "restructuredtext en"
logger = logging.getLogger("aesara.compile.profiling") logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time() aesara_imported_time = time.time()
...@@ -71,7 +66,8 @@ def _atexit_print_fn(): ...@@ -71,7 +66,8 @@ def _atexit_print_fn():
n_ops_to_print=config.profiling__n_ops, n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply, n_apply_to_print=config.profiling__n_apply,
) )
if not isinstance(ps, ScanProfileStats):
if ps.show_sum:
to_sum.append(ps) to_sum.append(ps)
else: else:
# TODO print the name if there is one! # TODO print the name if there is one!
...@@ -79,10 +75,7 @@ def _atexit_print_fn(): ...@@ -79,10 +75,7 @@ def _atexit_print_fn():
if len(to_sum) > 1: if len(to_sum) > 1:
# Make a global profile # Make a global profile
cum = copy.copy(to_sum[0]) cum = copy.copy(to_sum[0])
msg = ( msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
f"Sum of all({len(to_sum)}) printed profiles at exit excluding Scan op"
" profile."
)
cum.message = msg cum.message = msg
for ps in to_sum[1:]: for ps in to_sum[1:]:
for attr in [ for attr in [
...@@ -209,6 +202,7 @@ class ProfileStats: ...@@ -209,6 +202,7 @@ class ProfileStats:
# documented and initialized together. # documented and initialized together.
# dictionary variables are initialized with None. # dictionary variables are initialized with None.
# #
show_sum: bool = True
compile_time = 0.0 compile_time = 0.0
# Total time spent in body of orig_function, # Total time spent in body of orig_function,
...@@ -1729,61 +1723,3 @@ class ProfileStats: ...@@ -1729,61 +1723,3 @@ class ProfileStats:
] ]
for f in _profiler_printers: for f in _profiler_printers:
f(*params, file=file) f(*params, file=file)
class ScanProfileStats(ProfileStats):
    """Profile statistics reported under a "Scan Op profiling" header.

    The class-level defaults below are plain floats; they are presumably
    updated by the Scan op while it runs (the code doing so is not part of
    this class).
    """

    # Number of calls of the op, as reported by `summary_function`.
    callcount = 0.0
    # Total number of steps over all calls, as reported by `summary_function`.
    nbsteps = 0.0
    # Total time (in seconds) spent in the calls, as reported by
    # `summary_function`.
    call_time = 0.0

    def __init__(self, atexit_print=True, name=None, **kwargs):
        """Initialize the parent profile and remember an optional name.

        Parameters
        ----------
        atexit_print
            Forwarded positionally to ``ProfileStats.__init__``.
        name
            Optional label printed in the report header.
        **kwargs
            Forwarded unchanged to ``ProfileStats.__init__``.
        """
        super().__init__(atexit_print, **kwargs)
        self.name = name

    def summary_globals(self, file):
        """Print nothing: no extra global summary is wanted for this profile."""

    def summary_function(self, file):
        """Write the per-op timing report to `file`.

        Nothing is written when the op was never called.
        """
        # RP: every compiled function gets its own ProfileStats.  Whenever an
        # optimization replaces a scan op, an orphaned ProfileStats is left
        # behind; and even without optimizations scan compiles a dummy
        # function whose ProfileStats belongs to a function that is never
        # called.  Printing several empty "Function profiling" sections is
        # just confusing, so profiles without calls are skipped.
        if self.callcount == 0:
            return

        print("", file=file)
        if self.name is None:
            print("Scan Op profiling", file=file)
        else:
            print("Scan Op profiling (", self.name, ")", file=file)
        print("==================", file=file)
        print(f" Message: {self.message}", file=file)

        timing_line = (
            f" Time in {self.callcount} calls of the op (for a total of {self.nbsteps} "
            f"steps) {self.call_time:3}s"
        )
        print(timing_line, file=file)
        print("", file=file)

        # Share of the call time spent inside the VM; 0 when no time was
        # recorded, which also avoids dividing by zero.
        # NOTE(review): `vm_call_time` is not defined in this class; it is
        # presumably inherited from ProfileStats.
        vm_share = (
            self.vm_call_time * 100 / self.call_time if self.call_time > 0 else 0
        )
        print(
            f" Total time spent in calling the VM {self.vm_call_time:e}s ({vm_share:.3f}%)",
            file=file,
        )

        # Everything that is not VM time counts as overhead; 100 when no time
        # was recorded.
        overhead_share = (
            100.0 - self.vm_call_time * 100 / self.call_time
            if self.call_time > 0
            else 100
        )
        print(
            f" Total overhead (computing slices..) {self.call_time - self.vm_call_time:e}s ({overhead_share:.3f}%)",
            file=file,
        )
        print("", file=file)
...@@ -59,7 +59,7 @@ from aesara.compile.builders import infer_shape ...@@ -59,7 +59,7 @@ from aesara.compile.builders import infer_shape
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
from aesara.compile.mode import AddFeatureOptimizer, Mode, get_default_mode, get_mode from aesara.compile.mode import AddFeatureOptimizer, Mode, get_default_mode, get_mode
from aesara.compile.profiling import ScanProfileStats, register_profiler_printer from aesara.compile.profiling import register_profiler_printer
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
from aesara.graph.basic import ( from aesara.graph.basic import (
...@@ -77,7 +77,7 @@ from aesara.graph.op import HasInnerGraph, Op ...@@ -77,7 +77,7 @@ from aesara.graph.op import HasInnerGraph, Op
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
from aesara.scan.utils import Validator, forced_replace, safe_new from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import minimum from aesara.tensor.math import minimum
from aesara.tensor.shape import Shape_i from aesara.tensor.shape import Shape_i
......
...@@ -11,6 +11,7 @@ import numpy as np ...@@ -11,6 +11,7 @@ import numpy as np
from aesara import scalar as aes from aesara import scalar as aes
from aesara import tensor as at from aesara import tensor as at
from aesara.compile.profiling import ProfileStats
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import ( from aesara.graph.basic import (
Constant, Constant,
...@@ -123,6 +124,65 @@ class until: ...@@ -123,6 +124,65 @@ class until:
assert self.condition.ndim == 0 assert self.condition.ndim == 0
class ScanProfileStats(ProfileStats):
    """Profile statistics reported under a "Scan Op profiling" header.

    The class-level defaults below are plain floats; they are presumably
    updated by the Scan op while it runs (the code doing so is not part of
    this class).
    """

    # `aesara.compile.profiling._atexit_print_fn` only adds profiles whose
    # `show_sum` is true to the cumulative at-exit profile; this class opts
    # out.
    show_sum: bool = False
    # Number of calls of the op, as reported by `summary_function`.
    callcount = 0.0
    # Total number of steps over all calls, as reported by `summary_function`.
    nbsteps = 0.0
    # Total time (in seconds) spent in the calls, as reported by
    # `summary_function`.
    call_time = 0.0

    def __init__(self, atexit_print=True, name=None, **kwargs):
        """Initialize the parent profile and remember an optional name.

        Parameters
        ----------
        atexit_print
            Forwarded positionally to ``ProfileStats.__init__``.
        name
            Optional label printed in the report header.
        **kwargs
            Forwarded unchanged to ``ProfileStats.__init__``.
        """
        super().__init__(atexit_print, **kwargs)
        self.name = name

    def summary_globals(self, file):
        """Print nothing: no extra global summary is wanted for this profile."""

    def summary_function(self, file):
        """Write the per-op timing report to `file`.

        Nothing is written when the op was never called.

        Parameters
        ----------
        file
            Object passed as the ``file`` argument of `print`.
        """
        # RP: every time we compile a function a ProfileStats is created for
        # that function.  This means that every time an optimization replaces
        # some scan op, an orphaned ProfileStats is left behind.  Even without
        # any optimization, scan compiles a dummy function whose ProfileStats
        # corresponds to a function that will never be called.  Printing
        # several empty "Function profiling" sections is extremely confusing,
        # so profiles without calls are skipped.
        if self.callcount == 0:
            return

        print("", file=file)
        if self.name is not None:
            print("Scan Op profiling (", self.name, ")", file=file)
        else:
            print("Scan Op profiling", file=file)
        print("==================", file=file)
        print(f" Message: {self.message}", file=file)

        # The counters default to floats, so format them as integers to avoid
        # output such as "3.0 calls".  The elapsed time uses the same `e`
        # notation as the lines below; a bare `:3` would only set a minimum
        # field width.
        print(
            (
                f" Time in {int(self.callcount)} calls of the op "
                f"(for a total of {int(self.nbsteps)} "
                f"steps) {self.call_time:e}s"
            ),
            file=file,
        )
        print("", file=file)

        # Share of the call time spent inside the VM; 0 when no time was
        # recorded, which also avoids dividing by zero.
        # NOTE(review): `vm_call_time` is not defined in this class; it is
        # presumably inherited from ProfileStats.
        val = 0
        if self.call_time > 0:
            val = self.vm_call_time * 100 / self.call_time
        print(
            f" Total time spent in calling the VM {self.vm_call_time:e}s ({val:.3f}%)",
            file=file,
        )

        # Everything that is not VM time counts as overhead; 100 when no time
        # was recorded.
        val = 100
        if self.call_time > 0:
            val = 100.0 - self.vm_call_time * 100 / self.call_time
        print(
            f" Total overhead (computing slices..) {self.call_time - self.vm_call_time:e}s ({val:.3f}%)",
            file=file,
        )
        print("", file=file)
def traverse(out, x, x_copy, d, visited=None): def traverse(out, x, x_copy, d, visited=None):
""" """
Function used by scan to parse the tree and figure out which nodes Function used by scan to parse the tree and figure out which nodes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论