提交 0530834d authored 作者: Frederic Bastien's avatar Frederic Bastien

replaced the theano.gof.utils.get_theano_flags fct by a class…

replaced the theano.gof.utils.get_theano_flags fct by a class theano.gof.utils.[TheanoConfig,config] that is a variante of ConfigParse.
上级 b1038e91
...@@ -5,7 +5,7 @@ from theano.gof.cutils import run_cthunk ...@@ -5,7 +5,7 @@ from theano.gof.cutils import run_cthunk
from theano.compile.mode import Mode, register_mode, predefined_modes, predefined_linkers, predefined_optimizers, default_linker, default_optimizer from theano.compile.mode import Mode, register_mode, predefined_modes, predefined_linkers, predefined_optimizers, default_linker, default_optimizer
from theano.gof.cc import OpWiseCLinker from theano.gof.cc import OpWiseCLinker
from theano import gof from theano import gof
from theano.gof.utils import get_theano_flag from theano.gof.utils import config
class ProfileMode(Mode): class ProfileMode(Mode):
def __init__(self, linker=default_linker, optimizer=default_optimizer): def __init__(self, linker=default_linker, optimizer=default_optimizer):
...@@ -73,7 +73,7 @@ class ProfileMode(Mode): ...@@ -73,7 +73,7 @@ class ProfileMode(Mode):
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
self._optimizer = optimizer self._optimizer = optimizer
def print_summary(self, n_apply_to_print=15, n_ops_to_print=20): def print_summary(self, n_apply_to_print=None, n_ops_to_print=None):
""" Print 3 summary that show where the time is spend. The first show an Apply-wise summary, the second show an Op-wise summary, the third show an type-Op-wise summary. """ Print 3 summary that show where the time is spend. The first show an Apply-wise summary, the second show an Op-wise summary, the third show an type-Op-wise summary.
The Apply-wise summary print the timing information for the worst offending Apply nodes. This corresponds to individual Op applications within your graph which take the longest to execute (so if you use dot twice, you will see two entries there). The Apply-wise summary print the timing information for the worst offending Apply nodes. This corresponds to individual Op applications within your graph which take the longest to execute (so if you use dot twice, you will see two entries there).
...@@ -87,8 +87,8 @@ class ProfileMode(Mode): ...@@ -87,8 +87,8 @@ class ProfileMode(Mode):
:param n_ops_to_print: the number of ops to print. Default 20, or n_apply_to_print flag. :param n_ops_to_print: the number of ops to print. Default 20, or n_apply_to_print flag.
""" """
n_ops_to_print=int(get_theano_flag("n_ops_to_print", n_ops_to_print)) n_apply_to_print=config.getint("n_apply_to_print", n_apply_to_print)
n_apply_to_print=int(get_theano_flag("n_apply_to_print", n_apply_to_print)) n_ops_to_print=config.getint("n_ops_to_print", n_ops_to_print)
local_time = self.local_time[0] local_time = self.local_time[0]
compile_time = self.compile_time compile_time = self.compile_time
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# import variable # import variable
import re, os import re, os
import ConfigParser
def hashgen(): def hashgen():
hashgen.next += 1 hashgen.next += 1
...@@ -369,31 +370,103 @@ def type_guard(type1): ...@@ -369,31 +370,103 @@ def type_guard(type1):
return new_f return new_f
return wrap return wrap
default_={
'ProfileMode.n_apply_to_print':15,
'ProfileMode.n_ops_to_print':20,
'tensor_opt.local_elemwise_fusion':False,
'scalar_basic.amdlibm':False,
}
def get_theano_flag(key, default=None):
"""Return the value for a key passed via the THEANO_FLAGS environment variable.
:type key: a string class TheanoConfig(object):
:param key: the key to lookup """Return the value for a key after parsing ~/.theano.cfg and
:type default: any the THEANO_FLAGS environment variable.
:param default: the value to be returned if the key is not present. (Default: None)
if the variable don't exist return `default` We parse in that order the value to have:
if the key is not in the variable return `default` 1)the pair 'section.option':value in default_
if the key is in the variable but without a value return True 2)The ~/.theano.cfg file
if the key is in the variable with a value return the value 3)The value value provided in the get*() fct.
The last value found is the value returned.
if the key appears many times, we return the last value The THEANO_FLAGS environement variable should be a list of comma-separated [section.]option[=value] entries. If the section part is omited, their should be only one section with that contain the gived option.
The THEANO_FLAGS environement variable should be a list of comma-separated key[=value] entries.
""" """
f=os.getenv("THEANO_FLAGS", "")
ret = default def __init__(self):
key2=key+"=" d={} # no section
for fl in f.split(','): for k,v in default_.items():
if fl==key: if len(k.split('.'))==1:
ret = True d[k]=v
elif fl.startswith(key2):
ret = fl.split('=',1)[1] #set default value common for all section
self.config = ConfigParser.SafeConfigParser(d)
return ret
#set default value specific for each section
for k, v in default_.items():
sp = k.split('.',1)
if len(sp)==2:
if not self.config.has_section(sp[0]):
self.config.add_section(sp[0])
self.config.set(sp[0], sp[1], str(v))
#user config file override the default value
self.config.read(['site.cfg', os.path.expanduser('~/.theano.cfg')])
self.env_flags=os.getenv("THEANO_FLAGS","")
#The value in the env variable THEANO_FLAGS override the previous value
for flag in self.env_flags.split(','):
if not flag:
continue
sp=flag.split('=',1)
if len(sp)==1:
val=True
else:
val=sp[1]
val=str(val)
sp=sp[0].split('.',1)#option or section.option
if len(sp)==2:
self.config.set(sp[0],sp[1],val)
else:
found=0
for sec in self.config.sections():
for opt in self.config.options(sec):
if opt == sp[0]:
found+=1
section=sec
option=opt
if found==1:
self.config.set(section,option,val)
elif found>1:
raise Exception("Ambiguous option (%s) in THEANO_FLAGS"%(sp[0]))
def __getitem__(self, key):
return self.get(key)
def get(self, key, val=None):
#self.config.get(section, option, raw, vars=os.geteng('THEANO_DEFAULT'))
if val is not None:
return val
sp = key.split('.',1)
return self.config.get(sp[0],sp[1], False)
def getfloat(self, key, val=None):
if val is not None:
return float(val)
return float(self.get(key))
def getboolean(self, key, val=None):
if val is None:
val=self.get(key)
if val == "False" or val == "0" or not val:
val = False
else:
val = True
return val
def getint(self, key, val=None):
if val is not None:
return int(val)
return int(self.get(key))
config = TheanoConfig()
...@@ -8,7 +8,7 @@ _logger = logging.getLogger('theano.tensor.opt') ...@@ -8,7 +8,7 @@ _logger = logging.getLogger('theano.tensor.opt')
from theano import gof from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof.utils import MethodNotDefined, get_theano_flag from theano.gof.utils import MethodNotDefined, config
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from theano import scalar from theano import scalar
import basic as T import basic as T
...@@ -1339,7 +1339,7 @@ def local_elemwise_fusion(node): ...@@ -1339,7 +1339,7 @@ def local_elemwise_fusion(node):
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!" # print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs return n.outputs
if get_theano_flag('local_elemwise_fusion',False): if config.get('tensor_opt.local_elemwise_fusion'):
_logger.debug("enabling optimization: fusion elemwise") _logger.debug("enabling optimization: fusion elemwise")
register_specialize(local_elemwise_fusion) register_specialize(local_elemwise_fusion)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论