提交 b6c9f91d authored 作者: Frederic Bastien's avatar Frederic Bastien

add a parameter to thaeno flags to don't allow changing them after theano import.

上级 92ed9f15
...@@ -163,9 +163,14 @@ def AddConfigVar(name, doc, configparam, root=config): ...@@ -163,9 +163,14 @@ def AddConfigVar(name, doc, configparam, root=config):
_config_var_list.append(configparam) _config_var_list.append(configparam)
class ConfigParam(object): class ConfigParam(object):
def __init__(self, default, filter=None): def __init__(self, default, filter=None, allow_override=True):
"
If allow_override is False, we can't change the value after the import of Theano.
So the value should be the same during all the execution
"
self.default = default self.default = default
self.filter=filter self.filter=filter
self.allow_override = allow_override
# N.B. -- # N.B. --
# self.fullname # set by AddConfigVar # self.fullname # set by AddConfigVar
# self.doc # set by AddConfigVar # self.doc # set by AddConfigVar
...@@ -182,6 +187,8 @@ class ConfigParam(object): ...@@ -182,6 +187,8 @@ class ConfigParam(object):
return self.val return self.val
def __set__(self, cls, val): def __set__(self, cls, val):
if not self.allow_override and hasattr(self,'val'):
raise Exception("Can't change the value of this config parameter after initialization!")
#print "SETTING PARAM", self.fullname,(cls), val #print "SETTING PARAM", self.fullname,(cls), val
if self.filter: if self.filter:
self.val = self.filter(val) self.val = self.filter(val)
...@@ -191,7 +198,7 @@ class ConfigParam(object): ...@@ -191,7 +198,7 @@ class ConfigParam(object):
deleter=None deleter=None
class EnumStr(ConfigParam): class EnumStr(ConfigParam):
def __init__(self, default, *options): def __init__(self, default, *options, **kwargs):
self.default = default self.default = default
self.all = (default,) + options self.all = (default,) + options
def filter(val): def filter(val):
...@@ -200,13 +207,14 @@ class EnumStr(ConfigParam): ...@@ -200,13 +207,14 @@ class EnumStr(ConfigParam):
else: else:
raise ValueError('Invalid value (%s) for configuration variable "%s". Legal options are %s' raise ValueError('Invalid value (%s) for configuration variable "%s". Legal options are %s'
% (val, self.fullname, self.all), val) % (val, self.fullname, self.all), val)
super(EnumStr, self).__init__(default, filter) over = kwargs.get("allow_override", True)
super(EnumStr, self).__init__(default, filter, over)
def __str__(self): def __str__(self):
return '%s (%s) ' % (self.fullname, self.all) return '%s (%s) ' % (self.fullname, self.all)
class TypedParam(ConfigParam): class TypedParam(ConfigParam):
def __init__(self, default, mytype, is_valid=None): def __init__(self, default, mytype, is_valid=None, allow_override=True):
self.mytype = mytype self.mytype = mytype
def filter(val): def filter(val):
casted_val = mytype(val) casted_val = mytype(val)
...@@ -217,17 +225,17 @@ class TypedParam(ConfigParam): ...@@ -217,17 +225,17 @@ class TypedParam(ConfigParam):
raise ValueError('Invalid value (%s) for configuration variable "%s".' raise ValueError('Invalid value (%s) for configuration variable "%s".'
% (val, self.fullname), val) % (val, self.fullname), val)
return casted_val return casted_val
super(TypedParam, self).__init__(default, filter) super(TypedParam, self).__init__(default, filter, allow_override=allow_override)
def __str__(self): def __str__(self):
return '%s (%s) ' % (self.fullname, self.mytype) return '%s (%s) ' % (self.fullname, self.mytype)
def StrParam(default, is_valid=None): def StrParam(default, is_valid=None, allow_override=True):
return TypedParam(default, str, is_valid) return TypedParam(default, str, is_valid, allow_override=allow_override)
def IntParam(default, is_valid=None): def IntParam(default, is_valid=None, allow_override=True):
return TypedParam(default, int, is_valid) return TypedParam(default, int, is_valid, allow_override=allow_override)
def FloatParam(default, is_valid=None): def FloatParam(default, is_valid=None, allow_override=True):
return TypedParam(default, float, is_valid) return TypedParam(default, float, is_valid, allow_override=allow_override)
def BoolParam(default, is_valid=None): def BoolParam(default, is_valid=None, allow_override=True):
#see comment at the beggining of this file. #see comment at the beggining of this file.
def booltype(s): def booltype(s):
if s in ['False','false','0', False]: if s in ['False','false','0', False]:
...@@ -242,4 +250,4 @@ def BoolParam(default, is_valid=None): ...@@ -242,4 +250,4 @@ def BoolParam(default, is_valid=None):
return False return False
if is_valid is None: if is_valid is None:
is_valid = is_valid_bool is_valid = is_valid_bool
return TypedParam(default, booltype, is_valid) return TypedParam(default, booltype, is_valid, allow_override=allow_override)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论