提交 fce0b82f authored 作者: Frederic Bastien's avatar Frederic Bastien

Add optimizer profiling information.

上级 c657bad8
......@@ -225,12 +225,16 @@ class SeqOptimizer(Optimizer, list):
callback_before = fgraph.execute_callbacks_time
nb_node_before = len(fgraph.apply_nodes)
sub_profs = []
nb_nodes = []
for optimizer in self:
try:
nb_nodes_before = len(fgraph.apply_nodes)
t0 = time.time()
sub_prof = optimizer.optimize(fgraph)
l.append(float(time.time() - t0))
sub_profs.append(sub_prof)
nb_nodes.append((nb_nodes_before,
len(fgraph.apply_nodes)))
if fgraph.profile:
sub_validate_time.append(fgraph.profile.validate_time)
except AssertionError:
......@@ -249,7 +253,8 @@ class SeqOptimizer(Optimizer, list):
validate_time = None
callback_time = fgraph.execute_callbacks_time - callback_before
return (self, l, validate_time, callback_time, nb_node_before,
len(fgraph.apply_nodes), sub_profs, sub_validate_time)
len(fgraph.apply_nodes), sub_profs, sub_validate_time,
nb_nodes)
def __str__(self):
return "SeqOpt(%s)" % list.__str__(self)
......@@ -270,7 +275,7 @@ class SeqOptimizer(Optimizer, list):
@staticmethod
def print_profile(stream, prof, level=0):
(opts, prof, validate_time, callback_time, nb_node_before,
nb_node_after, sub_profs, sub_validate_time) = prof
nb_node_after, sub_profs, sub_validate_time, nb_nodes) = prof
blanc = (' ' * level)
print(blanc, "SeqOptimizer", end=' ', file=stream)
......@@ -284,18 +289,19 @@ class SeqOptimizer(Optimizer, list):
print(blanc, " %.3fs for callback" % (callback_time), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
if level == 0:
print(blanc, " time - (name, class, index) - validate time", file=stream)
print(blanc, " time - (name, class, index, nodes before, nodes after) - validate time", file=stream)
ll = []
for opt in opts:
if hasattr(opt, "__name__"):
ll.append((opt.__name__, opt.__class__.__name__,
opts.index(opt)))
name = opt.__name__
else:
ll.append((opt.name, opt.__class__.__name__,
opts.index(opt)))
lll = sorted(zip(prof, ll), key=lambda a: a[0])
name = opt.name
idx = opts.index(opt)
ll.append((name, opt.__class__.__name__,
idx) + nb_nodes[idx])
lll = sorted(zip(prof, ll, nb_nodes), key=lambda a: a[0])
for (t, opt) in lll[::-1]:
for (t, opt, nb_n) in lll[::-1]:
# if t < 1:
# continue
if sub_validate_time:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论