提交 bc0aee06 authored 作者: Frederic Bastien's avatar Frederic Bastien

create a hook system for the profiler and use it for scan.

上级 f1be580c
import time, atexit, copy
import theano
from theano.gof.link import WrapLinker
from theano.gof.cutils import run_cthunk
from theano.compile.mode import Mode, register_mode, predefined_modes, predefined_linkers, predefined_optimizers
......@@ -396,18 +395,10 @@ class ProfileMode(Mode):
%(max(0, len(atimes)-n_apply_to_print),
sum(f for f, t, a, nb_call in atimes[n_apply_to_print:]),
sum(t for f, t, a, nb_call in atimes[n_apply_to_print:]))
# Scan overhead profile
import theano # Why we need to re-import theano here? Otherwise is crash
if any([isinstance(node.op, (theano.Scan, theano.ScanGrad)) for (_,node) in apply_time.keys()]):
print
print 'Scan overhead:'
print '<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(% scan op time)> <sub scan op time(% scan op time)> <node>'
for (_,node),v in apply_time.items():
if isinstance(node.op, (theano.Scan, theano.ScanGrad)):
scan_fct_time = sum(node.op.mode_instance.fct_call_time.values())
scan_op_time = sum(node.op.mode_instance.local_time)
print ' %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%'%(v, scan_fct_time, scan_op_time, scan_fct_time/v*100, scan_op_time/v*100), node
for printer in profiler_printers:
printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, op_cimpl, message, outputs_size,
other_time)
if any([x[2].__name__.lower().startswith("gpu") for x in sotimes]):
print
......@@ -626,3 +617,10 @@ def atexit_print_default_profile_mode():
#Register atexit_print_default_profile_mode to have the summary of the
#predefined mode PROFILE_MODE if it is used printed when the program terminate.
atexit.register(atexit_print_default_profile_mode)
# Here we define an hook that allow to print extra profiling information
profiler_printers = []
def register_profiler_printer(fct):
profiler_printers.append(fct)
return fct
......@@ -1971,3 +1971,27 @@ optdb.register('scanOp_make_inplace', opt.in2out(scan_make_inplace,
@theano.compile.profilemode.register_profiler_printer
def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, op_cimpl, message, outputs_size,
other_time):
# Scan overhead profile
if any([isinstance(node.op, (Scan, ScanGrad)) for (_,node) in apply_time.keys()]):
print
print 'Scan overhead:'
print '<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(% scan op time)> <sub scan op time(% scan op time)> <node>'
total_super_scan_time = 0
total_scan_fct_time = 0
total_scan_op_time = 0
for (_,node),v in apply_time.items():
if isinstance(node.op, (Scan, ScanGrad)):
scan_fct_time = sum(node.op.mode_instance.fct_call_time.values())
scan_op_time = sum(node.op.mode_instance.local_time)
total_super_scan_time += v
total_scan_fct_time += scan_fct_time
total_scan_op_time += scan_op_time
print ' %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%'%(
v, scan_fct_time, scan_op_time, scan_fct_time/v*100,
scan_op_time/v*100), node
print ' total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%'%(
total_super_scan_time, total_scan_fct_time, total_scan_op_time, total_scan_fct_time/total_super_scan_time*100, total_scan_op_time/total_super_scan_time*100)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论