提交 567acc71 authored 作者: Frederic Bastien's avatar Frederic Bastien

implemented the diff of flops.

上级 1fb1b270
...@@ -71,13 +71,19 @@ class ProfileMode(Mode): ...@@ -71,13 +71,19 @@ class ProfileMode(Mode):
op_call = self.op_call op_call = self.op_call
op_cimpl = self.op_cimpl op_cimpl = self.op_cimpl
self.print_summary_("print_summary",local_time, compile_time, apply_time, apply_call, op_time, op_call, op_cimpl, n_apply_to_print, n_ops_to_print) op_flops = {}
for a,t in op_time.items():
if hasattr(a,'flops'):
op_flops[a]=a.flops*op_call[a]/t/1e6
self.print_summary_("print_summary",local_time, compile_time,
apply_time, apply_call, op_time, op_call, op_cimpl,
op_flops, n_apply_to_print, n_ops_to_print)
def print_diff_summary(self, other, n_apply_to_print=15, n_ops_to_print=20): def print_diff_summary(self, other, n_apply_to_print=15, n_ops_to_print=20):
""" As print_summary, but print the absolute difference on two different profile mode. """ As print_summary, but print the absolute difference on two different profile mode.
TODO: Also we don't print the Apply-wise summary as it don't work for now. TODO: Also we don't print the Apply-wise summary as it don't work for now.
TODO: make flops the difference of flops
TODO: make comparaison with gpu code. TODO: make comparaison with gpu code.
param: other the other instance of ProfileMode that we want to be compared to. param: other the other instance of ProfileMode that we want to be compared to.
...@@ -87,22 +93,36 @@ class ProfileMode(Mode): ...@@ -87,22 +93,36 @@ class ProfileMode(Mode):
param: n_ops_to_print the number of ops to print. Default 20. param: n_ops_to_print the number of ops to print. Default 20.
""" """
def diff_dict(a,b_): def diff_dict(a_time,b_time_):
r = {} r = {}
b = copy.copy(b_) b_time = copy.copy(b_time_)
for a,t in a.items(): for a,ta in a_time.items():
r.setdefault(a,0) r.setdefault(a,0)
t2 = b.pop(a,0) tb = b_time.pop(a,0)
#print t,t2,abs(t-t2),a r[a]+=abs(ta-tb)
r[a]+=abs(t-t2)
#they are missing in a #they are missing in a
print "missing items",len(b) print "missing items",len(b_time)
for a,t in b.items(): for a,t in b_time.items():
r.setdefault(a,0) r.setdefault(a,0)
r[a]+=t r[a]+=t
return r return r
def diff_dict_flops(a_time,b_time_,a_call,b_call):
flops = {}
b_time = copy.copy(b_time_)
for a,ta in a_time.items():
tb = b_time.pop(a,0)
if hasattr(a,'flops'):
flops[a]=abs(a.flops*a_call[a]/ta - a.flops*b_call[a]/tb)/1e6
#they are missing in a
for b,tb in b_time.items():
if hasattr(b,'flops'):
flops[b]=b.flops*b_call[b]/tb/1e6
return flops
local_time = abs(self.local_time[0]-other.local_time[0]) local_time = abs(self.local_time[0]-other.local_time[0])
compile_time = abs(self.compile_time-other.compile_time) compile_time = abs(self.compile_time-other.compile_time)
apply_time = diff_dict(self.apply_time, other.apply_time) apply_time = diff_dict(self.apply_time, other.apply_time)
...@@ -110,12 +130,17 @@ class ProfileMode(Mode): ...@@ -110,12 +130,17 @@ class ProfileMode(Mode):
op_time = diff_dict(self.op_time, other.op_time) op_time = diff_dict(self.op_time, other.op_time)
op_call = diff_dict(self.op_call, other.op_call) op_call = diff_dict(self.op_call, other.op_call)
op_cimpl = self.op_cimpl and other.op_cimpl op_cimpl = self.op_cimpl and other.op_cimpl
self.print_summary_("print_diff_summary",local_time, compile_time, apply_time, apply_call, op_time, op_call, op_cimpl, n_apply_to_print, n_ops_to_print, print_apply=False) op_flops = diff_dict_flops(self.op_time, other.op_time, self.op_call, other.op_call)
self.print_summary_("print_diff_summary",local_time, compile_time,
apply_time, apply_call, op_time, op_call, op_cimpl,
op_flops, n_apply_to_print=n_apply_to_print,
n_ops_to_print=n_ops_to_print, print_apply=False)
@staticmethod @staticmethod
def print_summary_(fct_name, local_time, compile_time, apply_time, apply_call, op_time, op_call, op_cimpl, def print_summary_(fct_name, local_time, compile_time, apply_time, apply_call, op_time, op_call, op_cimpl,
n_apply_to_print=15, n_ops_to_print=20, print_apply=True): op_flops=None, n_apply_to_print=15, n_ops_to_print=20, print_apply=True):
""" """
do the actual printing of print_summary and print_diff_summary. do the actual printing of print_summary and print_diff_summary.
...@@ -145,15 +170,10 @@ class ProfileMode(Mode): ...@@ -145,15 +170,10 @@ class ProfileMode(Mode):
sum(f for f, t, a, nb_call in atimes[n_apply_to_print:])*100, sum(f for f, t, a, nb_call in atimes[n_apply_to_print:])*100,
sum(t for f, t, a, nb_call in atimes[n_apply_to_print:])) sum(t for f, t, a, nb_call in atimes[n_apply_to_print:]))
flops=False
flops_msg='' flops_msg=''
for a,t in op_time.items(): if op_flops:
if hasattr(a,'flops'): flops_msg=' <MFlops/s>'
flops=True print '\nHACK WARNING: we print the flops for some OP, but the logic don\' always work. You need to know the internal of Theano to make it work correctly. Otherwise don\'t use!'
flops_msg=' <MFlops/s>'
print '\nHACK WARNING: we print the flops for some OP, but the logic don\' always work. You need to know the internal of Theano to make it work correctly. Otherwise don\'t use!'
break
print '\nOp-wise summary: < of local_time spent on this kind of Op> <cumulative seconds> <self seconds>%s <nb_call> <Op name>'%(flops_msg) print '\nOp-wise summary: < of local_time spent on this kind of Op> <cumulative seconds> <self seconds>%s <nb_call> <Op name>'%(flops_msg)
otimes = [(t/local_time, t, a, op_cimpl[a], op_call[a]) for a, t in op_time.items()] otimes = [(t/local_time, t, a, op_cimpl[a], op_call[a]) for a, t in op_time.items()]
...@@ -166,11 +186,8 @@ class ProfileMode(Mode): ...@@ -166,11 +186,8 @@ class ProfileMode(Mode):
msg = '*' msg = '*'
else: else:
msg = ' ' msg = ' '
m=-1 if op_flops:
if hasattr(a,'flops'): print ' %4.1f%% %.3fs %.3fs %s %7.1f %d %s' % (f*100, tot, t, msg, op_flops.get(a,-1), nb_call, a)
m=a.flops*op_call[a]/t/1e6
if flops:
print ' %4.1f%% %.3fs %.3fs %s %7.1f %d %s' % (f*100, tot, t, msg, m, nb_call, a)
else: else:
print ' %4.1f%% %.3fs %.3fs %s %s' % (f*100, tot, t, msg, a) print ' %4.1f%% %.3fs %.3fs %s %s' % (f*100, tot, t, msg, a)
print ' ... (remaining %i Ops account for %6.2f%%(%.2fs) of the runtime)'\ print ' ... (remaining %i Ops account for %6.2f%%(%.2fs) of the runtime)'\
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论