提交 fc62c756 authored 作者: Frederic's avatar Frederic

Add profiling detail for the MergeOptimizer

上级 9950ce08
...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception ...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise types that it can raise
""" """
import sys import sys
import time
import theano import theano
from theano.gof import graph from theano.gof import graph
...@@ -78,6 +79,8 @@ class FunctionGraph(utils.object2): ...@@ -78,6 +79,8 @@ class FunctionGraph(utils.object2):
""" """
self.execute_callbacks_time = 0
if features is None: if features is None:
features = [] features = []
...@@ -521,6 +524,7 @@ class FunctionGraph(utils.object2): ...@@ -521,6 +524,7 @@ class FunctionGraph(utils.object2):
getattr(feature, name)(*args) getattr(feature, name)(*args)
for each feature which has a method called after name. for each feature which has a method called after name.
""" """
t0 = time.time()
for feature in self._features: for feature in self._features:
try: try:
fn = getattr(feature, name) fn = getattr(feature, name)
...@@ -531,6 +535,7 @@ class FunctionGraph(utils.object2): ...@@ -531,6 +535,7 @@ class FunctionGraph(utils.object2):
continue continue
fn(self, *args, **kwargs) fn(self, *args, **kwargs)
self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""WRITEME """WRITEME
......
...@@ -539,6 +539,13 @@ class MergeOptimizer(Optimizer): ...@@ -539,6 +539,13 @@ class MergeOptimizer(Optimizer):
# Constant and non-constant are now applied in the same phase. # Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way. # I am not sure why, but it seems to be faster this way.
sched = fgraph.merge_feature.scheduled sched = fgraph.merge_feature.scheduled
nb_fail = 0
t0 = time.time()
if fgraph.profile:
validate_before = fgraph.profile.validate_time
callback_before = fgraph.execute_callbacks_time
nb_merged = 0
nb_constant = 0
while sched: while sched:
pairs_list = sched.pop() pairs_list = sched.pop()
success = True success = True
...@@ -547,17 +554,44 @@ class MergeOptimizer(Optimizer): ...@@ -547,17 +554,44 @@ class MergeOptimizer(Optimizer):
fgraph.replace_all_validate(pairs, 'Merge') fgraph.replace_all_validate(pairs, 'Merge')
except InconsistencyError: except InconsistencyError:
success = False success = False
nb_fail += 1
fgraph.merge_feature.blacklist.append( fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner)) (pairs[0][0].owner, pairs[0][1].owner))
if success: if success:
nb_merged += len(pairs)
if isinstance(pairs[0][0], graph.Constant):
nb_constant += 1
#print pairs, pairs[0][0].type
break break
if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
callback_time = fgraph.execute_callbacks_time - callback_before
else:
validate_time = None
callback_time = None
# clear blacklist # clear blacklist
fgraph.merge_feature.blacklist = [] fgraph.merge_feature.blacklist = []
return (nb_fail, time.time() - t0, validate_time,
callback_time, nb_merged, nb_constant)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
@staticmethod
def print_profile(stream, prof, level=0):
nb_fail, replace_time, validate_time, callback_time, nb_merged, nb_constant = prof
blanc = (' ' * level)
print >> stream, blanc, "MergeOptimizer"
print >> stream, blanc, " nb_fail", nb_fail
print >> stream, blanc, " replace_time", replace_time
print >> stream, blanc, " validate_time", validate_time
print >> stream, blanc, " callback_time", callback_time
print >> stream, blanc, " nb_merged", nb_merged
print >> stream, blanc, " nb_constant", nb_constant
merge_optimizer = MergeOptimizer() merge_optimizer = MergeOptimizer()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论