提交 0530834d authored 作者: Frederic Bastien's avatar Frederic Bastien

replaced the theano.gof.utils.get_theano_flags fct by a class…

replaced the theano.gof.utils.get_theano_flags fct by a class theano.gof.utils.[TheanoConfig,config] that is a variante of ConfigParse.
上级 b1038e91
......@@ -5,7 +5,7 @@ from theano.gof.cutils import run_cthunk
from theano.compile.mode import Mode, register_mode, predefined_modes, predefined_linkers, predefined_optimizers, default_linker, default_optimizer
from theano.gof.cc import OpWiseCLinker
from theano import gof
from theano.gof.utils import get_theano_flag
from theano.gof.utils import config
class ProfileMode(Mode):
def __init__(self, linker=default_linker, optimizer=default_optimizer):
......@@ -73,7 +73,7 @@ class ProfileMode(Mode):
optimizer = predefined_optimizers[optimizer]
self._optimizer = optimizer
def print_summary(self, n_apply_to_print=15, n_ops_to_print=20):
def print_summary(self, n_apply_to_print=None, n_ops_to_print=None):
""" Print 3 summary that show where the time is spend. The first show an Apply-wise summary, the second show an Op-wise summary, the third show an type-Op-wise summary.
The Apply-wise summary print the timing information for the worst offending Apply nodes. This corresponds to individual Op applications within your graph which take the longest to execute (so if you use dot twice, you will see two entries there).
......@@ -87,8 +87,8 @@ class ProfileMode(Mode):
:param n_ops_to_print: the number of ops to print. Default 20, or n_apply_to_print flag.
"""
n_ops_to_print=int(get_theano_flag("n_ops_to_print", n_ops_to_print))
n_apply_to_print=int(get_theano_flag("n_apply_to_print", n_apply_to_print))
n_apply_to_print=config.getint("n_apply_to_print", n_apply_to_print)
n_ops_to_print=config.getint("n_ops_to_print", n_ops_to_print)
local_time = self.local_time[0]
compile_time = self.compile_time
......
......@@ -3,6 +3,7 @@
# import variable
import re, os
import ConfigParser
def hashgen():
hashgen.next += 1
......@@ -369,31 +370,103 @@ def type_guard(type1):
return new_f
return wrap
default_={
'ProfileMode.n_apply_to_print':15,
'ProfileMode.n_ops_to_print':20,
'tensor_opt.local_elemwise_fusion':False,
'scalar_basic.amdlibm':False,
}
def get_theano_flag(key, default=None):
"""Return the value for a key passed via the THEANO_FLAGS environment variable.
:type key: a string
:param key: the key to lookup
:type default: any
:param default: the value to be returned if the key is not present. (Default: None)
if the variable don't exist return `default`
if the key is not in the variable return `default`
if the key is in the variable but without a value return True
if the key is in the variable with a value return the value
class TheanoConfig(object):
"""Return the value for a key after parsing ~/.theano.cfg and
the THEANO_FLAGS environment variable.
We parse in that order the value to have:
1)the pair 'section.option':value in default_
2)The ~/.theano.cfg file
3)The value value provided in the get*() fct.
The last value found is the value returned.
The THEANO_FLAGS environement variable should be a list of comma-separated [section.]option[=value] entries. If the section part is omited, their should be only one section with that contain the gived option.
"""
if the key appears many times, we return the last value
def __init__(self):
d={} # no section
for k,v in default_.items():
if len(k.split('.'))==1:
d[k]=v
The THEANO_FLAGS environement variable should be a list of comma-separated key[=value] entries.
"""
f=os.getenv("THEANO_FLAGS", "")
ret = default
key2=key+"="
for fl in f.split(','):
if fl==key:
ret = True
elif fl.startswith(key2):
ret = fl.split('=',1)[1]
#set default value common for all section
self.config = ConfigParser.SafeConfigParser(d)
#set default value specific for each section
for k, v in default_.items():
sp = k.split('.',1)
if len(sp)==2:
if not self.config.has_section(sp[0]):
self.config.add_section(sp[0])
self.config.set(sp[0], sp[1], str(v))
#user config file override the default value
self.config.read(['site.cfg', os.path.expanduser('~/.theano.cfg')])
self.env_flags=os.getenv("THEANO_FLAGS","")
#The value in the env variable THEANO_FLAGS override the previous value
for flag in self.env_flags.split(','):
if not flag:
continue
sp=flag.split('=',1)
if len(sp)==1:
val=True
else:
val=sp[1]
val=str(val)
sp=sp[0].split('.',1)#option or section.option
if len(sp)==2:
self.config.set(sp[0],sp[1],val)
else:
found=0
for sec in self.config.sections():
for opt in self.config.options(sec):
if opt == sp[0]:
found+=1
section=sec
option=opt
if found==1:
self.config.set(section,option,val)
elif found>1:
raise Exception("Ambiguous option (%s) in THEANO_FLAGS"%(sp[0]))
return ret
def __getitem__(self, key):
return self.get(key)
def get(self, key, val=None):
#self.config.get(section, option, raw, vars=os.geteng('THEANO_DEFAULT'))
if val is not None:
return val
sp = key.split('.',1)
return self.config.get(sp[0],sp[1], False)
def getfloat(self, key, val=None):
if val is not None:
return float(val)
return float(self.get(key))
def getboolean(self, key, val=None):
if val is None:
val=self.get(key)
if val == "False" or val == "0" or not val:
val = False
else:
val = True
return val
def getint(self, key, val=None):
if val is not None:
return int(val)
return int(self.get(key))
config = TheanoConfig()
......@@ -8,7 +8,7 @@ _logger = logging.getLogger('theano.tensor.opt')
from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof.utils import MethodNotDefined, get_theano_flag
from theano.gof.utils import MethodNotDefined, config
from elemwise import Elemwise, DimShuffle
from theano import scalar
import basic as T
......@@ -1339,7 +1339,7 @@ def local_elemwise_fusion(node):
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs
if get_theano_flag('local_elemwise_fusion',False):
if config.get('tensor_opt.local_elemwise_fusion'):
_logger.debug("enabling optimization: fusion elemwise")
register_specialize(local_elemwise_fusion)
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论