提交 201b4610 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a way to determine if an option has the default value or not.

Also starts adressing issue #3468 a bit.
上级 55f671d3
...@@ -102,7 +102,7 @@ def change_flags(**kwargs): ...@@ -102,7 +102,7 @@ def change_flags(**kwargs):
l = [v for v in theano.configparser._config_var_list l = [v for v in theano.configparser._config_var_list
if v.fullname == k] if v.fullname == k]
assert len(l) == 1 assert len(l) == 1
old_val[k] = l[0].__get__() old_val[k] = l[0].__get__(True, None)
try: try:
for k in kwargs: for k in kwargs:
l = [v for v in theano.configparser._config_var_list l = [v for v in theano.configparser._config_var_list
...@@ -167,7 +167,7 @@ def _config_print(thing, buf): ...@@ -167,7 +167,7 @@ def _config_print(thing, buf):
for cv in _config_var_list: for cv in _config_var_list:
print(cv, file=buf) print(cv, file=buf)
print(" Doc: ", cv.doc, file=buf) print(" Doc: ", cv.doc, file=buf)
print(" Value: ", cv.__get__(), file=buf) print(" Value: ", cv.__get__(True, None), file=buf)
print("", file=buf) print("", file=buf)
...@@ -182,7 +182,7 @@ def get_config_md5(): ...@@ -182,7 +182,7 @@ def get_config_md5():
all_opts = sorted([c for c in _config_var_list if c.in_c_key], all_opts = sorted([c for c in _config_var_list if c.in_c_key],
key=lambda cv: cv.fullname) key=lambda cv: cv.fullname)
return theano.gof.utils.hash_from_code('\n'.join( return theano.gof.utils.hash_from_code('\n'.join(
['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts])) ['%s = %s' % (cv.fullname, cv.__get__(True, None)) for cv in all_opts]))
class TheanoConfigParser(object): class TheanoConfigParser(object):
...@@ -270,14 +270,14 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -270,14 +270,14 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
# Trigger a read of the value from config files and env vars # Trigger a read of the value from config files and env vars
# This allow to filter wrong value from the user. # This allow to filter wrong value from the user.
if not callable(configparam.default): if not callable(configparam.default):
configparam.__get__() configparam.__get__(root, type(root))
else: else:
# We do not want to evaluate now the default value # We do not want to evaluate now the default value
# when it is a callable. # when it is a callable.
try: try:
fetch_val_for_key(configparam.fullname) fetch_val_for_key(configparam.fullname)
# The user provided a value, filter it now. # The user provided a value, filter it now.
configparam.__get__() configparam.__get__(root, type(root))
except KeyError: except KeyError:
pass pass
setattr(root.__class__, sections[0], configparam) setattr(root.__class__, sections[0], configparam)
...@@ -294,6 +294,7 @@ class ConfigParam(object): ...@@ -294,6 +294,7 @@ class ConfigParam(object):
self.default = default self.default = default
self.filter = filter self.filter = filter
self.allow_override = allow_override self.allow_override = allow_override
self.is_default = True
# N.B. -- # N.B. --
# self.fullname # set by AddConfigVar # self.fullname # set by AddConfigVar
# self.doc # set by AddConfigVar # self.doc # set by AddConfigVar
...@@ -304,10 +305,13 @@ class ConfigParam(object): ...@@ -304,10 +305,13 @@ class ConfigParam(object):
# Calling `filter` here may actually be harmful if the default value is # Calling `filter` here may actually be harmful if the default value is
# invalid and causes a crash or has unwanted side effects. # invalid and causes a crash or has unwanted side effects.
def __get__(self, *args): def __get__(self, cls, type_):
if cls is None:
return self
if not hasattr(self, 'val'): if not hasattr(self, 'val'):
try: try:
val_str = fetch_val_for_key(self.fullname) val_str = fetch_val_for_key(self.fullname)
self.is_default = False
except KeyError: except KeyError:
if callable(self.default): if callable(self.default):
val_str = self.default() val_str = self.default()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论