提交 12e9dd7b authored 作者: Frederic Bastien's avatar Frederic Bastien

lower profiler overhead by moving computation into the printing phase instead of…

Lower profiler overhead by moving the per-op aggregation (op time, op call count, flops) out of the recording phase and into the printing phase.
上级 5a0e9943
......@@ -25,9 +25,7 @@ class Profile_Maker(FunctionMaker):
for i, node in enumerate(ret.maker.env.toposort()):
self.mode.apply_time[(i,node.op)]=0.0
self.mode.apply_call[(i,node.op)]=0
self.mode.op_time[node.op]=0.
# self.mode.op_cimpl[node.op] =
self.mode.op_call[node.op] = 0
return ret
......@@ -36,16 +34,14 @@ class ProfileMode(Mode):
local_time = [0.0]
apply_time = {}
apply_call = {}
op_time = {}
op_cimpl = {}
op_call = {}
compile_time = 0 #time passed in theano.function()
fct_call_time = {}#time passed inside theano fct call including op time.
fct_call = {}
self.__setstate__((linker, optimizer, local_time,
apply_time, apply_call,
op_time, op_cimpl, op_call,
op_cimpl,
compile_time, fct_call_time, fct_call))
def function_maker(self, i,o,m, *args, **kwargs):
......@@ -58,19 +54,17 @@ class ProfileMode(Mode):
#print "__getstate__",self.provided_linker,self.provided_optimizer
return (self.provided_linker, self.provided_optimizer, self.local_time,
self.apply_time, self.apply_call,
self.op_time, self.op_cimpl, self.op_call, self.compile_time, self.fct_call_time, self.fct_call)
self.op_cimpl, self.compile_time, self.fct_call_time, self.fct_call)
def __setstate__(self, (linker, optimizer, local_time,
apply_time, apply_call,
op_time, op_cimpl, op_call,
op_cimpl,
compile_time, fct_call_time, fct_call)):
self.local_time = local_time
self.apply_time = apply_time
self.apply_call = apply_call
self.op_time = op_time
self.op_cimpl = op_cimpl
self.op_call = op_call
self.compile_time = compile_time
self.fct_call_time = fct_call_time
self.fct_call = fct_call
......@@ -94,9 +88,7 @@ class ProfileMode(Mode):
local_time[0] += dt
apply_time[(i,node.op)] += dt
apply_call[(i,node.op)] += 1
op_time[node.op] += dt
op_cimpl[node.op] = hasattr(th, 'cthunk')
op_call[node.op] += 1
self.provided_linker = linker
......@@ -133,18 +125,11 @@ class ProfileMode(Mode):
fct_call = self.fct_call
apply_time = self.apply_time
apply_call = self.apply_call
op_time = self.op_time
op_call = self.op_call
op_cimpl = self.op_cimpl
op_flops = {}
for a,t in op_time.items():
if hasattr(a,'flops'):
op_flops[a]=a.flops*op_call[a]/t/1e6
self.print_summary_("print_summary",local_time, compile_time, fct_call_time, fct_call,
apply_time, apply_call, op_time, op_call, op_cimpl,
op_flops, n_apply_to_print, n_ops_to_print)
apply_time, apply_call, op_cimpl,
n_apply_to_print, n_ops_to_print)
def print_diff_summary(self, other, n_apply_to_print=15, n_ops_to_print=20):
......@@ -173,42 +158,23 @@ class ProfileMode(Mode):
r[a]+=t
return r
def diff_dict_flops(a_time,b_time_,a_call,b_call):
    """
    Compute, per op, the difference in throughput (MFlops/s) between two
    profiles ``a`` and ``b``.

    :param a_time: dict mapping op -> accumulated time for profile a.
    :param b_time_: dict mapping op -> accumulated time for profile b.
        It is only read, never modified.
    :param a_call: dict mapping op -> number of calls for profile a.
    :param b_call: dict mapping op -> number of calls for profile b.
    :return: dict mapping op -> MFlops/s value, containing only the ops
        that expose a ``flops`` attribute.  For an op present in
        ``a_time`` the value is ``rate_a - rate_b``; for an op present
        only in ``b_time_`` the value is ``rate_b``.
    """
    def mflops(op, calls, seconds):
        # An op with no recorded time (never run, or absent from that
        # profile) has no measurable rate; report 0 rather than dividing
        # by zero.
        if not seconds:
            return 0.0
        return op.flops * calls / float(seconds) / 1e6

    flops = {}
    for a, ta in a_time.items():
        if hasattr(a, 'flops'):
            # Both rates are converted to MFlops/s *before* subtracting,
            # so the two terms are in the same unit.
            flops[a] = (mflops(a, a_call.get(a, 0), ta)
                        - mflops(a, b_call.get(a, 0), b_time_.get(a, 0)))

    # Ops that are missing in a: only the b side has data for them.
    # NOTE(review): the sign is kept positive as before; confirm whether
    # a "diff" should report these as negative.
    for b, tb in b_time_.items():
        if b not in a_time and hasattr(b, 'flops'):
            flops[b] = mflops(b, b_call.get(b, 0), tb)
    return flops
local_time = self.local_time[0]-other.local_time[0]
compile_time = self.compile_time-other.compile_time
fct_call_time = diff_dict(self.fct_call_time,other.fct_call_time)
fct_call = diff_dict(self.fct_call,other.fct_call)
apply_time = diff_dict(self.apply_time, other.apply_time)
apply_call = diff_dict(self.apply_call, other.apply_call)
op_time = diff_dict(self.op_time, other.op_time)
op_call = diff_dict(self.op_call, other.op_call)
op_cimpl = self.op_cimpl and other.op_cimpl
op_flops = diff_dict_flops(self.op_time, other.op_time, self.op_call, other.op_call)
self.print_summary_("print_diff_summary",local_time, compile_time, fct_call_time, fct_call,
apply_time, apply_call, op_time, op_call, op_cimpl,
op_flops, n_apply_to_print=n_apply_to_print,
apply_time, apply_call, op_cimpl,
n_apply_to_print=n_apply_to_print,
n_ops_to_print=n_ops_to_print, print_apply=False)
@staticmethod
def print_summary_(fct_name, local_time, compile_time, fct_call_time, fct_call,
apply_time, apply_call, op_time, op_call, op_cimpl,
op_flops=None, n_apply_to_print=15, n_ops_to_print=20, print_apply=True):
apply_time, apply_call, op_cimpl,
n_apply_to_print=15, n_ops_to_print=20, print_apply=True):
"""
do the actual printing of print_summary and print_diff_summary.
......@@ -238,6 +204,19 @@ class ProfileMode(Mode):
sum(f for f, t, a, nb_call in atimes[n_apply_to_print:])*100,
sum(t for f, t, a, nb_call in atimes[n_apply_to_print:]))
op_time = {}
op_call = {}
for (i,a),t in apply_time.items():
op_time.setdefault(a,0)
op_call.setdefault(a,0)
op_time[a]+=t
op_call[a]+=apply_call[(i,a)]
op_flops = {}
for a,t in op_time.items():
if hasattr(a,'flops'):
op_flops[a]=a.flops*op_call[a]/t/1e6
flops_msg=''
if op_flops:
flops_msg=' <MFlops/s>'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论