提交 6c4a5398 authored 作者: Faruk Ahmed's avatar Faruk Ahmed

profile --> (profile or print_global_stats

上级 d3dcb876
...@@ -664,7 +664,7 @@ class Function(object): ...@@ -664,7 +664,7 @@ class Function(object):
input_storage = [i.value for i in ins] input_storage = [i.value for i in ins]
# reinitialize new maker and create new function # reinitialize new maker and create new function
if profile is None: if profile is None:
profile = config.profile profile = config.profile or config.print_global_stats
# profile -> True or False # profile -> True or False
if profile is True: if profile is True:
if name: if name:
......
...@@ -364,7 +364,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -364,7 +364,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if givens is None: if givens is None:
givens = [] givens = []
if profile is None: if profile is None:
profile = config.profile profile = config.profile or config.print_global_stats
# profile -> True or False # profile -> True or False
if profile is False: if profile is False:
profile = None profile = None
......
...@@ -50,66 +50,67 @@ def _atexit_print_fn(): ...@@ -50,66 +50,67 @@ def _atexit_print_fn():
Print ProfileStat objects in _atexit_print_list to _atexit_print_file. Print ProfileStat objects in _atexit_print_list to _atexit_print_file.
""" """
to_sum = [] if config.profile:
to_sum = []
if config.profiling.destination == 'stderr': if config.profiling.destination == 'stderr':
destination_file = sys.stderr destination_file = sys.stderr
elif config.profiling.destination == 'stdout': elif config.profiling.destination == 'stdout':
destination_file = sys.stdout destination_file = sys.stdout
else:
destination_file = open(config.profiling.destination, 'w')
# Reverse sort in the order of compile+exec time
for ps in sorted(_atexit_print_list,
key=lambda a:a.compile_time + a.fct_call_time)[::-1]:
if ps.fct_callcount >= 1 or ps.compile_time > 1:
ps.summary(file=destination_file,
n_ops_to_print=config.profiling.n_ops,
n_apply_to_print=config.profiling.n_apply)
if not isinstance(ps, ScanProfileStats):
to_sum.append(ps)
else: else:
# TODO print the name if there is one! destination_file = open(config.profiling.destination, 'w')
print('Skipping empty Profile')
if len(to_sum) > 1: # Reverse sort in the order of compile+exec time
# Make a global profile for ps in sorted(_atexit_print_list,
cum = copy.copy(to_sum[0]) key=lambda a:a.compile_time + a.fct_call_time)[::-1]:
msg = ("Sum of all(%d) printed profiles at exit excluding Scan op" if ps.fct_callcount >= 1 or ps.compile_time > 1:
" profile." % len(to_sum)) ps.summary(file=destination_file,
cum.message = msg n_ops_to_print=config.profiling.n_ops,
for ps in to_sum[1:]: n_apply_to_print=config.profiling.n_apply)
for attr in ["compile_time", "fct_call_time", "fct_callcount", if not isinstance(ps, ScanProfileStats):
"vm_call_time", "optimizer_time", "linker_time", to_sum.append(ps)
"validate_time", "import_time",
"linker_node_make_thunks"]:
setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))
# merge dictonary
for attr in ["apply_time", "apply_callcount",
"apply_cimpl", "variable_shape", "variable_strides",
"linker_make_thunk_time"]:
cum_attr = getattr(cum, attr)
for key, val in iteritems(getattr(ps, attr)):
assert key not in cum_attr
cum_attr[key] = val
if cum.optimizer_profile and ps.optimizer_profile:
try:
merge = cum.optimizer_profile[0].merge_profile(
cum.optimizer_profile[1],
ps.optimizer_profile[1])
assert len(merge) == len(cum.optimizer_profile[1])
cum.optimizer_profile = (cum.optimizer_profile[0], merge)
except Exception as e:
print("Got an exception while merging profile")
print(e)
cum.optimizer_profile = None
else: else:
cum.optimizer_profile = None # TODO print the name if there is one!
print('Skipping empty Profile')
if len(to_sum) > 1:
# Make a global profile
cum = copy.copy(to_sum[0])
msg = ("Sum of all(%d) printed profiles at exit excluding Scan op"
" profile." % len(to_sum))
cum.message = msg
for ps in to_sum[1:]:
for attr in ["compile_time", "fct_call_time", "fct_callcount",
"vm_call_time", "optimizer_time", "linker_time",
"validate_time", "import_time",
"linker_node_make_thunks"]:
setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))
# merge dictonary
for attr in ["apply_time", "apply_callcount",
"apply_cimpl", "variable_shape", "variable_strides",
"linker_make_thunk_time"]:
cum_attr = getattr(cum, attr)
for key, val in iteritems(getattr(ps, attr)):
assert key not in cum_attr
cum_attr[key] = val
if cum.optimizer_profile and ps.optimizer_profile:
try:
merge = cum.optimizer_profile[0].merge_profile(
cum.optimizer_profile[1],
ps.optimizer_profile[1])
assert len(merge) == len(cum.optimizer_profile[1])
cum.optimizer_profile = (cum.optimizer_profile[0], merge)
except Exception as e:
print("Got an exception while merging profile")
print(e)
cum.optimizer_profile = None
else:
cum.optimizer_profile = None
cum.summary(file=destination_file, cum.summary(file=destination_file,
n_ops_to_print=config.profiling.n_ops, n_ops_to_print=config.profiling.n_ops,
n_apply_to_print=config.profiling.n_apply) n_apply_to_print=config.profiling.n_apply)
if config.print_global_stats: if config.print_global_stats:
print_global_stats() print_global_stats()
......
...@@ -482,7 +482,7 @@ class Stack(VM): ...@@ -482,7 +482,7 @@ class Stack(VM):
try: try:
_, dt = self.run_thunk_of_node(current_apply) _, dt = self.run_thunk_of_node(current_apply)
del _ del _
if config.profile: if config.profile or config.print_global_stats:
current_idx = self.node_idx[current_apply] current_idx = self.node_idx[current_apply]
self.call_counts[current_idx] += 1 self.call_counts[current_idx] += 1
self.call_times[current_idx] += dt self.call_times[current_idx] += dt
...@@ -596,7 +596,7 @@ class Stack(VM): ...@@ -596,7 +596,7 @@ class Stack(VM):
if current_apply.inputs[r].owner: if current_apply.inputs[r].owner:
apply_stack.append(current_apply.inputs[r].owner) apply_stack.append(current_apply.inputs[r].owner)
else: else:
if config.profile: if config.profile or config.print_global_stats:
for (idx, o) in enumerate(thunks[ for (idx, o) in enumerate(thunks[
self.node_idx[current_apply]].outputs): self.node_idx[current_apply]].outputs):
var = self.nodes[ var = self.nodes[
...@@ -757,7 +757,7 @@ class VM_Linker(link.LocalLinker): ...@@ -757,7 +757,7 @@ class VM_Linker(link.LocalLinker):
associated to self, else, a new VM_Linker associated to fgraph. associated to self, else, a new VM_Linker associated to fgraph.
""" """
if (config.profile and if ((config.profile or config.print_global_stats) and
((hasattr(theano, 'sandbox') and ((hasattr(theano, 'sandbox') and
hasattr(theano.sandbox, 'cuda') and hasattr(theano.sandbox, 'cuda') and
theano.sandbox.cuda.cuda_enabled) or theano.sandbox.cuda.cuda_enabled) or
...@@ -856,7 +856,7 @@ class VM_Linker(link.LocalLinker): ...@@ -856,7 +856,7 @@ class VM_Linker(link.LocalLinker):
pre_call_clear = [storage_map[v] for v in self.no_recycling] pre_call_clear = [storage_map[v] for v in self.no_recycling]
if (self.callback is not None or self.callback_input is not None or if (self.callback is not None or self.callback_input is not None or
(config.profile and config.profile_memory) or ((config.profile or config.print_global_stats) and config.profile_memory) or
(self.allow_partial_eval and not self.use_cloop)): (self.allow_partial_eval and not self.use_cloop)):
if self.use_cloop and (self.callback is not None or if self.use_cloop and (self.callback is not None or
...@@ -1086,7 +1086,7 @@ class VM_Linker(link.LocalLinker): ...@@ -1086,7 +1086,7 @@ class VM_Linker(link.LocalLinker):
lazy = config.vm.lazy lazy = config.vm.lazy
if lazy is None: if lazy is None:
lazy = not all([(not th.lazy) for th in thunks]) lazy = not all([(not th.lazy) for th in thunks])
if not (lazy or (config.profile and config.profile_memory) or if not (lazy or ((config.profile or config.print_global_stats) and config.profile_memory) or
self.use_cloop or self.callback or self.callback_input): self.use_cloop or self.callback or self.callback_input):
for pair in itervalues(reallocated_info): for pair in itervalues(reallocated_info):
storage_map[pair[1]] = storage_map[pair[0]] storage_map[pair[1]] = storage_map[pair[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论