提交 b6c9f91d authored 作者: Frederic Bastien's avatar Frederic Bastien

add a parameter to thaeno flags to don't allow changing them after theano import.

上级 92ed9f15
......@@ -163,9 +163,14 @@ def AddConfigVar(name, doc, configparam, root=config):
_config_var_list.append(configparam)
class ConfigParam(object):
def __init__(self, default, filter=None):
def __init__(self, default, filter=None, allow_override=True):
"
If allow_override is False, we can't change the value after the import of Theano.
So the value should be the same during all the execution
"
self.default = default
self.filter=filter
self.allow_override = allow_override
# N.B. --
# self.fullname # set by AddConfigVar
# self.doc # set by AddConfigVar
......@@ -182,6 +187,8 @@ class ConfigParam(object):
return self.val
def __set__(self, cls, val):
if not self.allow_override and hasattr(self,'val'):
raise Exception("Can't change the value of this config parameter after initialization!")
#print "SETTING PARAM", self.fullname,(cls), val
if self.filter:
self.val = self.filter(val)
......@@ -191,7 +198,7 @@ class ConfigParam(object):
deleter=None
class EnumStr(ConfigParam):
def __init__(self, default, *options):
def __init__(self, default, *options, **kwargs):
self.default = default
self.all = (default,) + options
def filter(val):
......@@ -200,13 +207,14 @@ class EnumStr(ConfigParam):
else:
raise ValueError('Invalid value (%s) for configuration variable "%s". Legal options are %s'
% (val, self.fullname, self.all), val)
super(EnumStr, self).__init__(default, filter)
over = kwargs.get("allow_override", True)
super(EnumStr, self).__init__(default, filter, over)
def __str__(self):
return '%s (%s) ' % (self.fullname, self.all)
class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None):
def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype
def filter(val):
casted_val = mytype(val)
......@@ -217,17 +225,17 @@ class TypedParam(ConfigParam):
raise ValueError('Invalid value (%s) for configuration variable "%s".'
% (val, self.fullname), val)
return casted_val
super(TypedParam, self).__init__(default, filter)
super(TypedParam, self).__init__(default, filter, allow_override=allow_override)
def __str__(self):
return '%s (%s) ' % (self.fullname, self.mytype)
def StrParam(default, is_valid=None):
return TypedParam(default, str, is_valid)
def IntParam(default, is_valid=None):
return TypedParam(default, int, is_valid)
def FloatParam(default, is_valid=None):
return TypedParam(default, float, is_valid)
def BoolParam(default, is_valid=None):
def StrParam(default, is_valid=None, allow_override=True):
return TypedParam(default, str, is_valid, allow_override=allow_override)
def IntParam(default, is_valid=None, allow_override=True):
return TypedParam(default, int, is_valid, allow_override=allow_override)
def FloatParam(default, is_valid=None, allow_override=True):
return TypedParam(default, float, is_valid, allow_override=allow_override)
def BoolParam(default, is_valid=None, allow_override=True):
#see comment at the beggining of this file.
def booltype(s):
if s in ['False','false','0', False]:
......@@ -242,4 +250,4 @@ def BoolParam(default, is_valid=None):
return False
if is_valid is None:
is_valid = is_valid_bool
return TypedParam(default, booltype, is_valid)
return TypedParam(default, booltype, is_valid, allow_override=allow_override)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论