提交 7fd3fd68 authored 作者: Frederic's avatar Frederic

more opt profiling info.

上级 be62dd96
...@@ -1500,7 +1500,13 @@ class GemmOptimizer(Optimizer): ...@@ -1500,7 +1500,13 @@ class GemmOptimizer(Optimizer):
time_factor_can = 0 time_factor_can = 0
time_factor_list = 0 time_factor_list = 0
time_toposort = 0 time_toposort = 0
if fgraph.profile:
validate_before = fgraph.profile.validate_time
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
while did_something: while did_something:
nb_iter += 1
t0 = time.time() t0 = time.time()
nodelist = theano.gof.graph.io_toposort(fgraph.inputs, fgraph.outputs) nodelist = theano.gof.graph.io_toposort(fgraph.inputs, fgraph.outputs)
time_toposort += time.time() - t0 time_toposort += time.time() - t0
...@@ -1545,16 +1551,29 @@ class GemmOptimizer(Optimizer): ...@@ -1545,16 +1551,29 @@ class GemmOptimizer(Optimizer):
except ReplacementDidntRemovedError, e: except ReplacementDidntRemovedError, e:
nb_replacement_didn_t_remove += 1 nb_replacement_didn_t_remove += 1
self.warned = True self.warned = True
nb_iter += 1 if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
callback_time = fgraph.execute_callbacks_time - callback_before
callbacks_time = {}
for k, v in fgraph.execute_callbacks_times.iteritems():
if k in callbacks_before:
callbacks_time[k] = v - callbacks_before[k]
else:
callbacks_time[k] = v
else:
validate_time = None
callback_time = None
callbacks_time = {}
return (self, nb_iter, nb_replacement, nb_replacement_didn_t_remove, return (self, nb_iter, nb_replacement, nb_replacement_didn_t_remove,
nb_inconsistency_make, nb_inconsistency_replace, nb_inconsistency_make, nb_inconsistency_replace,
time_canonicalize, time_factor_can, time_canonicalize, time_factor_can,
time_factor_list, time_toposort) time_factor_list, time_toposort,
validate_time, callback_time, callbacks_time,)
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
blanc = (' ' * level) blanc = (' ' * level)
#1946.912556s - ('gemm_optimizer', 'GemmOptimizer', 1)
print >> stream, blanc, "GemmOptimizer" print >> stream, blanc, "GemmOptimizer"
print >> stream, blanc, " nb_iter", prof[1] print >> stream, blanc, " nb_iter", prof[1]
print >> stream, blanc, " nb_replacement", prof[2] print >> stream, blanc, " nb_replacement", prof[2]
...@@ -1565,6 +1584,12 @@ class GemmOptimizer(Optimizer): ...@@ -1565,6 +1584,12 @@ class GemmOptimizer(Optimizer):
print >> stream, blanc, " time_factor_can", prof[7] print >> stream, blanc, " time_factor_can", prof[7]
print >> stream, blanc, " time_factor_list", prof[8] print >> stream, blanc, " time_factor_list", prof[8]
print >> stream, blanc, " time_toposort", prof[9] print >> stream, blanc, " time_toposort", prof[9]
print >> stream, blanc, " validate_time", prof[10]
print >> stream, blanc, " callback_time", prof[11]
print >> stream, blanc, " callbacks_time"
for i in sorted(prof[12].iteritems(), key=lambda a: a[1]):
if i[1] > 0:
print i
class Dot22(GemmRelated): class Dot22(GemmRelated):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论