提交 fc62c756 authored 作者: Frederic's avatar Frederic

Add profiling detail for the MergeOptimizer

上级 9950ce08
......@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise
"""
import sys
import time
import theano
from theano.gof import graph
......@@ -78,6 +79,8 @@ class FunctionGraph(utils.object2):
"""
self.execute_callbacks_time = 0
if features is None:
features = []
......@@ -521,6 +524,7 @@ class FunctionGraph(utils.object2):
getattr(feature, name)(*args)
for each feature which has a method called after name.
"""
t0 = time.time()
for feature in self._features:
try:
fn = getattr(feature, name)
......@@ -531,6 +535,7 @@ class FunctionGraph(utils.object2):
continue
fn(self, *args, **kwargs)
self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args):
"""WRITEME
......
......@@ -539,6 +539,13 @@ class MergeOptimizer(Optimizer):
# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched = fgraph.merge_feature.scheduled
nb_fail = 0
t0 = time.time()
if fgraph.profile:
validate_before = fgraph.profile.validate_time
callback_before = fgraph.execute_callbacks_time
nb_merged = 0
nb_constant = 0
while sched:
pairs_list = sched.pop()
success = True
......@@ -547,17 +554,44 @@ class MergeOptimizer(Optimizer):
fgraph.replace_all_validate(pairs, 'Merge')
except InconsistencyError:
success = False
nb_fail += 1
fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
if success:
nb_merged += len(pairs)
if isinstance(pairs[0][0], graph.Constant):
nb_constant += 1
#print pairs, pairs[0][0].type
break
if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
callback_time = fgraph.execute_callbacks_time - callback_before
else:
validate_time = None
callback_time = None
# clear blacklist
fgraph.merge_feature.blacklist = []
return (nb_fail, time.time() - t0, validate_time,
callback_time, nb_merged, nb_constant)
def __str__(self):
return self.__class__.__name__
@staticmethod
def print_profile(stream, prof, level=0):
nb_fail, replace_time, validate_time, callback_time, nb_merged, nb_constant = prof
blanc = (' ' * level)
print >> stream, blanc, "MergeOptimizer"
print >> stream, blanc, " nb_fail", nb_fail
print >> stream, blanc, " replace_time", replace_time
print >> stream, blanc, " validate_time", validate_time
print >> stream, blanc, " callback_time", callback_time
print >> stream, blanc, " nb_merged", nb_merged
print >> stream, blanc, " nb_constant", nb_constant
merge_optimizer = MergeOptimizer()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论