提交 fd29e66f authored 作者: Frederic Bastien's avatar Frederic Bastien

add a theano flags to print the time inside each optimizer into a SeqOptimizer

上级 f552b784
...@@ -4,7 +4,7 @@ amount of useful generic optimization tools. ...@@ -4,7 +4,7 @@ amount of useful generic optimization tools.
""" """
import sys, logging import sys, logging, time
import graph import graph
from env import InconsistencyError from env import InconsistencyError
import utils import utils
...@@ -13,12 +13,17 @@ import toolbox ...@@ -13,12 +13,17 @@ import toolbox
import op import op
from copy import copy from copy import copy
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
from theano.configparser import AddConfigVar, BoolParam, config
#if sys.version_info[:2] >= (2,5): #if sys.version_info[:2] >= (2,5):
# from collections import defaultdict # from collections import defaultdict
_logger = logging.getLogger('theano.gof.opt') _logger = logging.getLogger('theano.gof.opt')
AddConfigVar('time_seq_optimizer',
"Should SeqOptimizer print the time taked by each of its optimizer",
BoolParam(False))
from theano.gof import deque from theano.gof import deque
import destroyhandler as dh import destroyhandler as dh
import traceback import traceback
...@@ -126,9 +131,13 @@ class SeqOptimizer(Optimizer, list): ...@@ -126,9 +131,13 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME """WRITEME
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
l=[]
nb_node_before = len(env.nodes)
for optimizer in self: for optimizer in self:
try: try:
t0=time.time()
optimizer.optimize(env) optimizer.optimize(env)
l.append(float(time.time()-t0))
except AssertionError: # do not catch Assertion failures except AssertionError: # do not catch Assertion failures
raise raise
except Exception, e: except Exception, e:
...@@ -137,6 +146,20 @@ class SeqOptimizer(Optimizer, list): ...@@ -137,6 +146,20 @@ class SeqOptimizer(Optimizer, list):
continue continue
else: else:
raise raise
if config.time_seq_optimizer:
print "SeqOptimizer",
if hasattr(self,"name"): print self.name,
elif hasattr(self,"__name__"): print self.__name__,
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(env.nodes))
ll = [(opt.__name__ if hasattr(opt,"__name__")else opt.name,opt.__class__.__name__) for opt in self]
lll=zip(l,ll)
def cmp(a,b):
if a[0]==b[0]: return 0
if a[0]<b[0]: return -1
return 1
lll.sort(cmp)
print lll
def __eq__(self, other): def __eq__(self, other):
#added to override the list's __eq__ implementation #added to override the list's __eq__ implementation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论