提交 11e7b1fe authored 作者: Josh Bleecher Snyder's avatar Josh Bleecher Snyder

Make platform-dependent compiledir configurable

上级 724a2569
...@@ -60,8 +60,16 @@ def config_files_from_theanorc(): ...@@ -60,8 +60,16 @@ def config_files_from_theanorc():
rval.append(os.path.expanduser('~/.theanorc.txt')) rval.append(os.path.expanduser('~/.theanorc.txt'))
return rval return rval
theano_cfg = ConfigParser.SafeConfigParser({'USER':os.getenv("USER", os.path.split(os.path.expanduser('~'))[-1])})
theano_cfg.read(config_files_from_theanorc()) config_files = config_files_from_theanorc()
theano_cfg = ConfigParser.SafeConfigParser({'USER': os.getenv("USER", os.path.split(os.path.expanduser('~'))[-1])})
theano_cfg.read(config_files)
# Having a raw version of the config around as well enables us to pass
# through config values that contain format strings.
# The time required to parse the config twice is negligible.
theano_raw_cfg = ConfigParser.RawConfigParser()
theano_raw_cfg.read(config_files)
def fetch_val_for_key(key): def fetch_val_for_key(key):
"""Return the overriding config value for a key. """Return the overriding config value for a key.
...@@ -91,8 +99,11 @@ def fetch_val_for_key(key): ...@@ -91,8 +99,11 @@ def fetch_val_for_key(key):
section, option = key_tokens section, option = key_tokens
else: else:
section, option = 'global', key section, option = 'global', key
try:
try: try:
return theano_cfg.get(section, option) return theano_cfg.get(section, option)
except ConfigParser.InterpolationError:
return theano_raw_cfg.get(section, option)
except (ConfigParser.NoOptionError, ConfigParser.NoSectionError): except (ConfigParser.NoOptionError, ConfigParser.NoSectionError):
raise KeyError(key) raise KeyError(key)
......
...@@ -4,19 +4,32 @@ import os ...@@ -4,19 +4,32 @@ import os
import platform import platform
import re import re
import sys import sys
import textwrap
import theano import theano
from theano.configparser import config, AddConfigVar, ConfigParam, StrParam from theano.configparser import config, AddConfigVar, ConfigParam, StrParam
compiledir_format_dict = {"platform": platform.platform(),
"processor": platform.processor(),
"python_version": platform.python_version(),
"theano_version": theano.__version__,
}
compiledir_format_keys = ", ".join(compiledir_format_dict.keys())
default_compiledir_format = "compiledir_%(platform)s-%(processor)s-%(python_version)s"
AddConfigVar("compiledir_format",
textwrap.fill(textwrap.dedent("""\
Format string for platform-dependent compiled
module subdirectory (relative to base_compiledir).
Available keys: %s. Defaults to %r.
""" % (compiledir_format_keys, default_compiledir_format))),
StrParam(default_compiledir_format, allow_override=False))
def default_compiledirname(): def default_compiledirname():
platform_id = '-'.join([ formatted = config.compiledir_format % compiledir_format_dict
platform.platform(), safe = re.sub("[\(\)\s,]+", "_", formatted)
platform.processor(), return safe
platform.python_version(),
theano.__version__])
platform_id = re.sub("[\(\)\s,]+", "_", platform_id)
return 'compiledir_' + platform_id
def filter_compiledir(path): def filter_compiledir(path):
...@@ -78,12 +91,11 @@ else: ...@@ -78,12 +91,11 @@ else:
AddConfigVar('base_compiledir', AddConfigVar('base_compiledir',
"arch-independent cache directory for compiled modules", "platform-independent root directory for compiled modules",
StrParam(default_base_compiledir, allow_override=False)) StrParam(default_base_compiledir, allow_override=False))
AddConfigVar('compiledir', AddConfigVar('compiledir',
"arch-dependent cache directory for compiled modules", "platform-dependent cache directory for compiled modules",
ConfigParam( ConfigParam(
os.path.join( os.path.join(
os.path.expanduser(config.base_compiledir), os.path.expanduser(config.base_compiledir),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论