提交 a0eb0ad0 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Thomas Wiecki

Move constants and helper functions in configdefaults

+ BITWIDTHs are now constants in utils. + Lambdas and local functions for default/filter/validate are now module-level local functions. This fixes pickleability (closes #240).
上级 cd98c61a
...@@ -242,16 +242,16 @@ def test_no_more_dotting(): ...@@ -242,16 +242,16 @@ def test_no_more_dotting():
def test_mode_apply(): def test_mode_apply():
assert configdefaults.filter_mode("DebugMode") == "DebugMode" assert configdefaults._filter_mode("DebugMode") == "DebugMode"
with pytest.raises(ValueError, match="Expected one of"): with pytest.raises(ValueError, match="Expected one of"):
configdefaults.filter_mode("not_a_mode") configdefaults._filter_mode("not_a_mode")
# test with theano.Mode instance # test with theano.Mode instance
import theano.compile.mode import theano.compile.mode
assert ( assert (
configdefaults.filter_mode(theano.compile.mode.FAST_COMPILE) configdefaults._filter_mode(theano.compile.mode.FAST_COMPILE)
== theano.compile.mode.FAST_COMPILE == theano.compile.mode.FAST_COMPILE
) )
......
import distutils.spawn
import errno import errno
import logging import logging
import os import os
import platform import platform
import re import re
import socket import socket
import struct
import sys import sys
import textwrap import textwrap
...@@ -23,6 +23,8 @@ from theano.configparser import ( ...@@ -23,6 +23,8 @@ from theano.configparser import (
StrParam, StrParam,
) )
from theano.utils import ( from theano.utils import (
LOCAL_BITWIDTH,
PYTHON_INT_BITWIDTH,
call_subprocess_Popen, call_subprocess_Popen,
maybe_add_to_os_environ_pathlist, maybe_add_to_os_environ_pathlist,
output_subprocess_Popen, output_subprocess_Popen,
...@@ -50,14 +52,14 @@ def get_cuda_root(): ...@@ -50,14 +52,14 @@ def get_cuda_root():
def default_cuda_include(): def default_cuda_include():
if theano.config.cuda__root: if config.cuda__root:
return os.path.join(theano.config.cuda__root, "include") return os.path.join(config.cuda__root, "include")
return "" return ""
def default_dnn_base_path(): def default_dnn_base_path():
# We want to default to the cuda root if cudnn is installed there # We want to default to the cuda root if cudnn is installed there
root = theano.config.cuda__root root = config.cuda__root
# The include doesn't change location between OS. # The include doesn't change location between OS.
if root and os.path.exists(os.path.join(root, "include", "cudnn.h")): if root and os.path.exists(os.path.join(root, "include", "cudnn.h")):
return root return root
...@@ -65,34 +67,34 @@ def default_dnn_base_path(): ...@@ -65,34 +67,34 @@ def default_dnn_base_path():
def default_dnn_inc_path(): def default_dnn_inc_path():
if theano.config.dnn__base_path != "": if config.dnn__base_path != "":
return os.path.join(theano.config.dnn__base_path, "include") return os.path.join(config.dnn__base_path, "include")
return "" return ""
def default_dnn_lib_path(): def default_dnn_lib_path():
if theano.config.dnn__base_path != "": if config.dnn__base_path != "":
if sys.platform == "win32": if sys.platform == "win32":
path = os.path.join(theano.config.dnn__base_path, "lib", "x64") path = os.path.join(config.dnn__base_path, "lib", "x64")
elif sys.platform == "darwin": elif sys.platform == "darwin":
path = os.path.join(theano.config.dnn__base_path, "lib") path = os.path.join(config.dnn__base_path, "lib")
else: else:
# This is linux # This is linux
path = os.path.join(theano.config.dnn__base_path, "lib64") path = os.path.join(config.dnn__base_path, "lib64")
return path return path
return "" return ""
def default_dnn_bin_path(): def default_dnn_bin_path():
if theano.config.dnn__base_path != "": if config.dnn__base_path != "":
if sys.platform == "win32": if sys.platform == "win32":
return os.path.join(theano.config.dnn__base_path, "bin") return os.path.join(config.dnn__base_path, "bin")
else: else:
return theano.config.dnn__library_path return config.dnn__library_path
return "" return ""
def filter_mode(val): def _filter_mode(val):
# Do not add FAST_RUN_NOGC to this list (nor any other ALL CAPS shortcut). # Do not add FAST_RUN_NOGC to this list (nor any other ALL CAPS shortcut).
# The way to get FAST_RUN_NOGC is with the flag 'linker=c|py_nogc'. # The way to get FAST_RUN_NOGC is with the flag 'linker=c|py_nogc'.
# The old all capital letter way of working is deprecated as it is not # The old all capital letter way of working is deprecated as it is not
...@@ -121,7 +123,7 @@ def filter_mode(val): ...@@ -121,7 +123,7 @@ def filter_mode(val):
) )
def warn_cxx(val): def _warn_cxx(val):
"""We only support clang++ as otherwise we hit strange g++/OSX bugs.""" """We only support clang++ as otherwise we hit strange g++/OSX bugs."""
if sys.platform == "darwin" and val and "clang++" not in val: if sys.platform == "darwin" and val and "clang++" not in val:
_logger.warning( _logger.warning(
...@@ -131,14 +133,14 @@ def warn_cxx(val): ...@@ -131,14 +133,14 @@ def warn_cxx(val):
return True return True
def split_version(version): def _split_version(version):
""" """
Take version as a dot-separated string, return a tuple of int Take version as a dot-separated string, return a tuple of int
""" """
return tuple(int(i) for i in version.split(".")) return tuple(int(i) for i in version.split("."))
def warn_default(version): def _warn_default(version):
""" """
Return True iff we should warn about bugs fixed after a given version. Return True iff we should warn about bugs fixed after a given version.
""" """
...@@ -146,12 +148,12 @@ def warn_default(version): ...@@ -146,12 +148,12 @@ def warn_default(version):
return True return True
if config.warn__ignore_bug_before == "all": if config.warn__ignore_bug_before == "all":
return False return False
if split_version(config.warn__ignore_bug_before) >= split_version(version): if _split_version(config.warn__ignore_bug_before) >= _split_version(version):
return False return False
return True return True
def good_seed_param(seed): def _good_seem_param(seed):
if seed == "random": if seed == "random":
return True return True
try: try:
...@@ -161,7 +163,7 @@ def good_seed_param(seed): ...@@ -161,7 +163,7 @@ def good_seed_param(seed):
return True return True
def is_valid_check_preallocated_output_param(param): def _is_valid_check_preallocated_output_param(param):
if not isinstance(param, str): if not isinstance(param, str):
return False return False
valid = [ valid = [
...@@ -181,36 +183,10 @@ def is_valid_check_preallocated_output_param(param): ...@@ -181,36 +183,10 @@ def is_valid_check_preallocated_output_param(param):
def _timeout_default(): def _timeout_default():
return theano.config.compile__wait * 24 return config.compile__wait * 24
def local_bitwidth(): def _filter_vm_lazy(val):
"""
Return 32 for 32bit arch, 64 for 64bit arch.
By "architecture", we mean the size of memory pointers (size_t in C),
*not* the size of long int, as it can be different.
"""
# Note that according to Python documentation, `platform.architecture()` is
# not reliable on OS X with universal binaries.
# Also, sys.maxsize does not exist in Python < 2.6.
# 'P' denotes a void*, and the size is expressed in bytes.
return struct.calcsize("P") * 8
def python_int_bitwidth():
"""
Return the bit width of Python int (C long int).
Note that it can be different from the size of a memory pointer.
"""
# 'l' denotes a C long int, and the size is expressed in bytes.
return struct.calcsize("l") * 8
def filter_vm_lazy(val):
if val == "False" or val is False: if val == "False" or val is False:
return False return False
elif val == "True" or val is True: elif val == "True" or val is True:
...@@ -605,12 +581,20 @@ def add_magma_configvars(): ...@@ -605,12 +581,20 @@ def add_magma_configvars():
) )
def _is_gt_0(x):
return x > 0
def _is_greater_or_equal_0(x):
return x >= 0
def add_compile_configvars(): def add_compile_configvars():
config.add( config.add(
"mode", "mode",
"Default compilation mode", "Default compilation mode",
ConfigParam("Mode", apply=filter_mode), ConfigParam("Mode", apply=_filter_mode),
in_c_key=False, in_c_key=False,
) )
...@@ -655,13 +639,10 @@ def add_compile_configvars(): ...@@ -655,13 +639,10 @@ def add_compile_configvars():
# Try to find the full compiler path from the name # Try to find the full compiler path from the name
if param != "": if param != "":
import distutils.spawn
newp = distutils.spawn.find_executable(param) newp = distutils.spawn.find_executable(param)
if newp is not None: if newp is not None:
param = newp param = newp
del newp del newp
del distutils
# to support path that includes spaces, we need to wrap it with double quotes on Windows # to support path that includes spaces, we need to wrap it with double quotes on Windows
if param and os.name == "nt": if param and os.name == "nt":
...@@ -673,7 +654,7 @@ def add_compile_configvars(): ...@@ -673,7 +654,7 @@ def add_compile_configvars():
" supported, but supporting additional compilers should not be " " supported, but supporting additional compilers should not be "
"too difficult. " "too difficult. "
"If it is empty, no C++ code is compiled.", "If it is empty, no C++ code is compiled.",
StrParam(param, validate=warn_cxx), StrParam(param, validate=_warn_cxx),
in_c_key=False, in_c_key=False,
) )
del param del param
...@@ -823,7 +804,7 @@ def add_compile_configvars(): ...@@ -823,7 +804,7 @@ def add_compile_configvars():
config.add( config.add(
"compile__wait", "compile__wait",
"""Time to wait before retrying to acquire the compile lock.""", """Time to wait before retrying to acquire the compile lock.""",
IntParam(5, validate=lambda i: i > 0, mutable=False), IntParam(5, validate=_is_gt_0, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -834,7 +815,7 @@ def add_compile_configvars(): ...@@ -834,7 +815,7 @@ def add_compile_configvars():
lock is held by the same owner *and* has not been 'refreshed' by this lock is held by the same owner *and* has not been 'refreshed' by this
owner for more than this period. Refreshes are done every half timeout owner for more than this period. Refreshes are done every half timeout
period for running processes.""", period for running processes.""",
IntParam(_timeout_default, validate=lambda i: i >= 0, mutable=False), IntParam(_timeout_default, validate=_is_greater_or_equal_0, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -848,6 +829,10 @@ def add_compile_configvars(): ...@@ -848,6 +829,10 @@ def add_compile_configvars():
) )
def _is_valid_cmp_sloppy(v):
return v in (0, 1, 2)
def add_tensor_configvars(): def add_tensor_configvars():
# This flag is used when we import Theano to initialize global variables. # This flag is used when we import Theano to initialize global variables.
...@@ -857,7 +842,7 @@ def add_tensor_configvars(): ...@@ -857,7 +842,7 @@ def add_tensor_configvars():
config.add( config.add(
"tensor__cmp_sloppy", "tensor__cmp_sloppy",
"Relax tensor._allclose (0) not at all, (1) a bit, (2) more", "Relax tensor._allclose (0) not at all, (1) a bit, (2) more",
IntParam(0, lambda i: i in (0, 1, 2), mutable=False), IntParam(0, _is_valid_cmp_sloppy, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -1080,6 +1065,14 @@ def add_error_and_warning_configvars(): ...@@ -1080,6 +1065,14 @@ def add_error_and_warning_configvars():
) )
def _has_cxx():
return bool(config.cxx)
def _is_valid_check_strides(v):
return v in (0, 1, 2)
def add_testvalue_and_checking_configvars(): def add_testvalue_and_checking_configvars():
config.add( config.add(
"print_test_value", "print_test_value",
...@@ -1155,14 +1148,14 @@ def add_testvalue_and_checking_configvars(): ...@@ -1155,14 +1148,14 @@ def add_testvalue_and_checking_configvars():
config.add( config.add(
"DebugMode__patience", "DebugMode__patience",
"Optimize graph this many times to detect inconsistency", "Optimize graph this many times to detect inconsistency",
IntParam(10, lambda i: i > 0), IntParam(10, _is_gt_0),
in_c_key=False, in_c_key=False,
) )
config.add( config.add(
"DebugMode__check_c", "DebugMode__check_c",
"Run C implementations where possible", "Run C implementations where possible",
BoolParam(lambda: bool(theano.config.cxx)), BoolParam(_has_cxx),
in_c_key=False, in_c_key=False,
) )
...@@ -1186,7 +1179,8 @@ def add_testvalue_and_checking_configvars(): ...@@ -1186,7 +1179,8 @@ def add_testvalue_and_checking_configvars():
"Check that Python- and C-produced ndarrays have same strides. " "Check that Python- and C-produced ndarrays have same strides. "
"On difference: (0) - ignore, (1) warn, or (2) raise error" "On difference: (0) - ignore, (1) warn, or (2) raise error"
), ),
IntParam(0, lambda i: i in (0, 1, 2)), # TODO: make this an Enum setting
IntParam(0, _is_valid_check_strides),
in_c_key=False, in_c_key=False,
) )
...@@ -1213,7 +1207,7 @@ def add_testvalue_and_checking_configvars(): ...@@ -1213,7 +1207,7 @@ def add_testvalue_and_checking_configvars():
'"wrong_size" (larger and smaller dimensions), and ' '"wrong_size" (larger and smaller dimensions), and '
'"ALL" (all of the above).' '"ALL" (all of the above).'
), ),
StrParam("", validate=is_valid_check_preallocated_output_param), StrParam("", validate=_is_valid_check_preallocated_output_param),
in_c_key=False, in_c_key=False,
) )
...@@ -1226,7 +1220,7 @@ def add_testvalue_and_checking_configvars(): ...@@ -1226,7 +1220,7 @@ def add_testvalue_and_checking_configvars():
"to reduce memory or time usage, but it is advised to keep a " "to reduce memory or time usage, but it is advised to keep a "
"minimum of 2." "minimum of 2."
), ),
IntParam(4, lambda i: i > 0), IntParam(4, _is_gt_0),
in_c_key=False, in_c_key=False,
) )
...@@ -1240,21 +1234,21 @@ def add_testvalue_and_checking_configvars(): ...@@ -1240,21 +1234,21 @@ def add_testvalue_and_checking_configvars():
config.add( config.add(
"profiling__n_apply", "profiling__n_apply",
"Number of Apply instances to print by default", "Number of Apply instances to print by default",
IntParam(20, lambda i: i > 0), IntParam(20, _is_gt_0),
in_c_key=False, in_c_key=False,
) )
config.add( config.add(
"profiling__n_ops", "profiling__n_ops",
"Number of Ops to print by default", "Number of Ops to print by default",
IntParam(20, lambda i: i > 0), IntParam(20, _is_gt_0),
in_c_key=False, in_c_key=False,
) )
config.add( config.add(
"profiling__output_line_width", "profiling__output_line_width",
"Max line width for the profiling output", "Max line width for the profiling output",
IntParam(512, lambda i: i > 0), IntParam(512, _is_gt_0),
in_c_key=False, in_c_key=False,
) )
...@@ -1262,7 +1256,7 @@ def add_testvalue_and_checking_configvars(): ...@@ -1262,7 +1256,7 @@ def add_testvalue_and_checking_configvars():
"profiling__min_memory_size", "profiling__min_memory_size",
"""For the memory profile, do not print Apply nodes if the size """For the memory profile, do not print Apply nodes if the size
of their outputs (in bytes) is lower than this threshold""", of their outputs (in bytes) is lower than this threshold""",
IntParam(1024, lambda i: i >= 0), IntParam(1024, _is_greater_or_equal_0),
in_c_key=False, in_c_key=False,
) )
...@@ -1482,7 +1476,7 @@ def add_vm_configvars(): ...@@ -1482,7 +1476,7 @@ def add_vm_configvars():
" auto detect if lazy evaluation is needed and use the appropriate" " auto detect if lazy evaluation is needed and use the appropriate"
" version. If lazy is True/False, force the version used between" " version. If lazy is True/False, force the version used between"
" Loop/LoopGC and Stack.", " Loop/LoopGC and Stack.",
ConfigParam("None", apply=filter_vm_lazy), ConfigParam("None", apply=_filter_vm_lazy),
in_c_key=False, in_c_key=False,
) )
...@@ -1505,7 +1499,7 @@ def add_deprecated_configvars(): ...@@ -1505,7 +1499,7 @@ def add_deprecated_configvars():
"unittests__rseed", "unittests__rseed",
"Seed to use for randomized unit tests. " "Seed to use for randomized unit tests. "
"Special value 'random' means using a seed of None.", "Special value 'random' means using a seed of None.",
StrParam(666, validate=good_seed_param), StrParam(666, validate=_good_seem_param),
in_c_key=False, in_c_key=False,
) )
...@@ -1514,7 +1508,7 @@ def add_deprecated_configvars(): ...@@ -1514,7 +1508,7 @@ def add_deprecated_configvars():
"warn__identify_1pexp_bug", "warn__identify_1pexp_bug",
"Warn if Theano versions prior to 7987b51 (2011-12-18) could have " "Warn if Theano versions prior to 7987b51 (2011-12-18) could have "
"yielded a wrong result due to a bug in the is_1pexp function", "yielded a wrong result due to a bug in the is_1pexp function",
BoolParam(warn_default("0.4.1")), BoolParam(_warn_default("0.4.1")),
in_c_key=False, in_c_key=False,
) )
# TODO: this setting is not used anywhere # TODO: this setting is not used anywhere
...@@ -1542,7 +1536,7 @@ def add_deprecated_configvars(): ...@@ -1542,7 +1536,7 @@ def add_deprecated_configvars():
"theano.tensor.nnet.nnet.local_argmax_pushdown optimization. " "theano.tensor.nnet.nnet.local_argmax_pushdown optimization. "
"Was fixed 27 may 2010" "Was fixed 27 may 2010"
), ),
BoolParam(warn_default("0.3")), BoolParam(_warn_default("0.3")),
in_c_key=False, in_c_key=False,
) )
...@@ -1553,7 +1547,7 @@ def add_deprecated_configvars(): ...@@ -1553,7 +1547,7 @@ def add_deprecated_configvars():
"silent bug with GpuSum pattern 01,011 and 0111 when the first " "silent bug with GpuSum pattern 01,011 and 0111 when the first "
"dimensions was bigger then 4096. Was fixed 31 may 2010" "dimensions was bigger then 4096. Was fixed 31 may 2010"
), ),
BoolParam(warn_default("0.3")), BoolParam(_warn_default("0.3")),
in_c_key=False, in_c_key=False,
) )
...@@ -1566,7 +1560,7 @@ def add_deprecated_configvars(): ...@@ -1566,7 +1560,7 @@ def add_deprecated_configvars():
"sums in the graph, bad code was generated. " "sums in the graph, bad code was generated. "
"Was fixed 2 August 2010" "Was fixed 2 August 2010"
), ),
BoolParam(warn_default("0.3")), BoolParam(_warn_default("0.3")),
in_c_key=False, in_c_key=False,
) )
...@@ -1578,7 +1572,7 @@ def add_deprecated_configvars(): ...@@ -1578,7 +1572,7 @@ def add_deprecated_configvars():
"would have given incorrect result. This bug was triggered by " "would have given incorrect result. This bug was triggered by "
"sum of division of dimshuffled tensors." "sum of division of dimshuffled tensors."
), ),
BoolParam(warn_default("0.3")), BoolParam(_warn_default("0.3")),
in_c_key=False, in_c_key=False,
) )
...@@ -1587,7 +1581,7 @@ def add_deprecated_configvars(): ...@@ -1587,7 +1581,7 @@ def add_deprecated_configvars():
"Warn if previous versions of Theano (before 0.5rc2) could have given " "Warn if previous versions of Theano (before 0.5rc2) could have given "
"incorrect results when indexing into a subtensor with negative " "incorrect results when indexing into a subtensor with negative "
"stride (for instance, for instance, x[a:b:-1][c]).", "stride (for instance, for instance, x[a:b:-1][c]).",
BoolParam(warn_default("0.5")), BoolParam(_warn_default("0.5")),
in_c_key=False, in_c_key=False,
) )
...@@ -1596,7 +1590,7 @@ def add_deprecated_configvars(): ...@@ -1596,7 +1590,7 @@ def add_deprecated_configvars():
"Warn if previous versions of Theano (before 0.6) could have given " "Warn if previous versions of Theano (before 0.6) could have given "
"incorrect results when moving to the gpu " "incorrect results when moving to the gpu "
"set_subtensor(x[int vector], new_value)", "set_subtensor(x[int vector], new_value)",
BoolParam(warn_default("0.6")), BoolParam(_warn_default("0.6")),
in_c_key=False, in_c_key=False,
) )
...@@ -1619,7 +1613,7 @@ def add_deprecated_configvars(): ...@@ -1619,7 +1613,7 @@ def add_deprecated_configvars():
"Warn we use the new signal.conv2d() when its interface" "Warn we use the new signal.conv2d() when its interface"
" changed mid June 2014" " changed mid June 2014"
), ),
BoolParam(warn_default("0.7")), BoolParam(_warn_default("0.7")),
in_c_key=False, in_c_key=False,
) )
...@@ -1636,7 +1630,7 @@ def add_deprecated_configvars(): ...@@ -1636,7 +1630,7 @@ def add_deprecated_configvars():
"did not check the reduction axis. So if the " "did not check the reduction axis. So if the "
"reduction axis was not 0, you got a wrong answer." "reduction axis was not 0, you got a wrong answer."
), ),
BoolParam(warn_default("0.7")), BoolParam(_warn_default("0.7")),
in_c_key=False, in_c_key=False,
) )
...@@ -1648,7 +1642,7 @@ def add_deprecated_configvars(): ...@@ -1648,7 +1642,7 @@ def add_deprecated_configvars():
"when using some patterns of advanced indexing (indexing with " "when using some patterns of advanced indexing (indexing with "
"one vector or matrix of ints)." "one vector or matrix of ints)."
), ),
BoolParam(warn_default("0.7")), BoolParam(_warn_default("0.7")),
in_c_key=False, in_c_key=False,
) )
...@@ -1657,7 +1651,7 @@ def add_deprecated_configvars(): ...@@ -1657,7 +1651,7 @@ def add_deprecated_configvars():
"Warn when using `tensor.round` with the default mode. " "Warn when using `tensor.round` with the default mode. "
"Round changed its default from `half_away_from_zero` to " "Round changed its default from `half_away_from_zero` to "
"`half_to_even` to have the same default as NumPy.", "`half_to_even` to have the same default as NumPy.",
BoolParam(warn_default("0.9")), BoolParam(_warn_default("0.9")),
in_c_key=False, in_c_key=False,
) )
...@@ -1667,7 +1661,7 @@ def add_deprecated_configvars(): ...@@ -1667,7 +1661,7 @@ def add_deprecated_configvars():
"given incorrect results when computing " "given incorrect results when computing "
"inc_subtensor(zeros[idx], x)[idx], when idx is an array of integers " "inc_subtensor(zeros[idx], x)[idx], when idx is an array of integers "
"with duplicated values.", "with duplicated values.",
BoolParam(warn_default("0.10")), BoolParam(_warn_default("0.10")),
in_c_key=False, in_c_key=False,
) )
...@@ -1696,24 +1690,108 @@ def add_scan_configvars(): ...@@ -1696,24 +1690,108 @@ def add_scan_configvars():
) )
def _get_default_gpuarray__cache_path():
return os.path.join(config.compiledir, "gpuarray_kernels")
def _default_compiledirname():
formatted = config.compiledir_format % _compiledir_format_dict
safe = re.sub(r"[\(\)\s,]+", "_", formatted)
return safe
def _filter_base_compiledir(path):
# Expand '~' in path
return os.path.expanduser(str(path))
def _filter_compiledir(path):
# Expand '~' in path
path = os.path.expanduser(path)
# Turn path into the 'real' path. This ensures that:
# 1. There is no relative path, which would fail e.g. when trying to
# import modules from the compile dir.
# 2. The path is stable w.r.t. e.g. symlinks (which makes it easier
# to re-use compiled modules).
path = os.path.realpath(path)
if os.access(path, os.F_OK): # Do it exist?
if not os.access(path, os.R_OK | os.W_OK | os.X_OK):
# If it exist we need read, write and listing access
raise ValueError(
f"compiledir '{path}' exists but you don't have read, write"
" or listing permissions."
)
else:
try:
os.makedirs(path, 0o770) # read-write-execute for user and group
except OSError as e:
# Maybe another parallel execution of theano was trying to create
# the same directory at the same time.
if e.errno != errno.EEXIST:
raise ValueError(
"Unable to create the compiledir directory"
f" '{path}'. Check the permissions."
)
# PROBLEM: sometimes the initial approach based on
# os.system('touch') returned -1 for an unknown reason; the
# alternate approach here worked in all cases... it was weird.
# No error should happen as we checked the permissions.
init_file = os.path.join(path, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
else:
e.args += (f"{path} exist? {os.path.exists(path)}",)
raise
return path
def _get_home_dir():
"""
Return location of the user's home directory.
"""
home = os.getenv("HOME")
if home is None:
# This expanduser usually works on Windows (see discussion on
# theano-users, July 13 2010).
home = os.path.expanduser("~")
if home == "~":
# This might happen when expanduser fails. Although the cause of
# failure is a mystery, it has been seen on some Windows system.
home = os.getenv("USERPROFILE")
assert home is not None
return home
_compiledir_format_dict = {
"platform": platform.platform(),
"processor": platform.processor(),
"python_version": platform.python_version(),
"python_bitwidth": LOCAL_BITWIDTH,
"python_int_bitwidth": PYTHON_INT_BITWIDTH,
"theano_version": theano.__version__,
"numpy_version": np.__version__,
"gxx_version": "xxx",
"hostname": socket.gethostname(),
}
def _default_compiledir():
return os.path.join(config.base_compiledir, _default_compiledirname())
def add_caching_dir_configvars(): def add_caching_dir_configvars():
compiledir_format_dict = { _compiledir_format_dict["gxx_version"] = (gcc_version_str.replace(" ", "_"),)
"platform": platform.platform(), _compiledir_format_dict["short_platform"] = short_platform()
"processor": platform.processor(),
"python_version": platform.python_version(),
"python_bitwidth": local_bitwidth(),
"python_int_bitwidth": python_int_bitwidth(),
"theano_version": theano.__version__,
"numpy_version": np.__version__,
"gxx_version": gcc_version_str.replace(" ", "_"),
"hostname": socket.gethostname(),
}
compiledir_format_dict["short_platform"] = short_platform()
# Allow to have easily one compiledir per device. # Allow to have easily one compiledir per device.
compiledir_format_dict["device"] = config.device _compiledir_format_dict["device"] = config.device
compiledir_format_keys = ", ".join(sorted(compiledir_format_dict.keys())) compiledir_format_keys = ", ".join(sorted(_compiledir_format_dict.keys()))
default_compiledir_format = ( _default_compiledir_format = (
"compiledir_%(short_platform)s-%(processor)s-" "compiledir_%(short_platform)s-%(processor)s-"
"%(python_version)s-%(python_bitwidth)s" "%(python_version)s-%(python_bitwidth)s"
) )
...@@ -1725,83 +1803,13 @@ def add_caching_dir_configvars(): ...@@ -1725,83 +1803,13 @@ def add_caching_dir_configvars():
f"""\ f"""\
Format string for platform-dependent compiled Format string for platform-dependent compiled
module subdirectory (relative to base_compiledir). module subdirectory (relative to base_compiledir).
Available keys: {compiledir_format_keys}. Defaults to {default_compiledir_format}. Available keys: {compiledir_format_keys}. Defaults to {_default_compiledir_format}.
""" """
) )
), ),
StrParam(default_compiledir_format, mutable=False), StrParam(_default_compiledir_format, mutable=False),
in_c_key=False, in_c_key=False,
) )
def default_compiledirname():
formatted = theano.config.compiledir_format % compiledir_format_dict
safe = re.sub(r"[\(\)\s,]+", "_", formatted)
return safe
def filter_base_compiledir(path):
# Expand '~' in path
return os.path.expanduser(str(path))
def filter_compiledir(path):
# Expand '~' in path
path = os.path.expanduser(path)
# Turn path into the 'real' path. This ensures that:
# 1. There is no relative path, which would fail e.g. when trying to
# import modules from the compile dir.
# 2. The path is stable w.r.t. e.g. symlinks (which makes it easier
# to re-use compiled modules).
path = os.path.realpath(path)
if os.access(path, os.F_OK): # Do it exist?
if not os.access(path, os.R_OK | os.W_OK | os.X_OK):
# If it exist we need read, write and listing access
raise ValueError(
f"compiledir '{path}' exists but you don't have read, write"
" or listing permissions."
)
else:
try:
os.makedirs(path, 0o770) # read-write-execute for user and group
except OSError as e:
# Maybe another parallel execution of theano was trying to create
# the same directory at the same time.
if e.errno != errno.EEXIST:
raise ValueError(
"Unable to create the compiledir directory"
f" '{path}'. Check the permissions."
)
# PROBLEM: sometimes the initial approach based on
# os.system('touch') returned -1 for an unknown reason; the
# alternate approach here worked in all cases... it was weird.
# No error should happen as we checked the permissions.
init_file = os.path.join(path, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
else:
e.args += (f"{path} exist? {os.path.exists(path)}",)
raise
return path
def get_home_dir():
"""
Return location of the user's home directory.
"""
home = os.getenv("HOME")
if home is None:
# This expanduser usually works on Windows (see discussion on
# theano-users, July 13 2010).
home = os.path.expanduser("~")
if home == "~":
# This might happen when expanduser fails. Although the cause of
# failure is a mystery, it has been seen on some Windows system.
home = os.getenv("USERPROFILE")
assert home is not None
return home
# On Windows we should avoid writing temporary files to a directory that is # On Windows we should avoid writing temporary files to a directory that is
# part of the roaming part of the user profile. Instead we use the local part # part of the roaming part of the user profile. Instead we use the local part
...@@ -1809,24 +1817,21 @@ def add_caching_dir_configvars(): ...@@ -1809,24 +1817,21 @@ def add_caching_dir_configvars():
if sys.platform == "win32" and os.getenv("LOCALAPPDATA") is not None: if sys.platform == "win32" and os.getenv("LOCALAPPDATA") is not None:
default_base_compiledir = os.path.join(os.getenv("LOCALAPPDATA"), "Theano") default_base_compiledir = os.path.join(os.getenv("LOCALAPPDATA"), "Theano")
else: else:
default_base_compiledir = os.path.join(get_home_dir(), ".theano") default_base_compiledir = os.path.join(_get_home_dir(), ".theano")
config.add( config.add(
"base_compiledir", "base_compiledir",
"platform-independent root directory for compiled modules", "platform-independent root directory for compiled modules",
ConfigParam( ConfigParam(
default_base_compiledir, apply=filter_base_compiledir, mutable=False default_base_compiledir, apply=_filter_base_compiledir, mutable=False
), ),
in_c_key=False, in_c_key=False,
) )
def default_compiledir():
return os.path.join(theano.config.base_compiledir, default_compiledirname())
config.add( config.add(
"compiledir", "compiledir",
"platform-dependent cache directory for compiled modules", "platform-dependent cache directory for compiled modules",
ConfigParam(default_compiledir, apply=filter_compiledir, mutable=False), ConfigParam(_default_compiledir, apply=_filter_compiledir, mutable=False),
in_c_key=False, in_c_key=False,
) )
...@@ -1834,8 +1839,8 @@ def add_caching_dir_configvars(): ...@@ -1834,8 +1839,8 @@ def add_caching_dir_configvars():
"gpuarray__cache_path", "gpuarray__cache_path",
"Directory to cache pre-compiled kernels for the gpuarray backend.", "Directory to cache pre-compiled kernels for the gpuarray backend.",
ConfigParam( ConfigParam(
lambda: os.path.join(config.compiledir, "gpuarray_kernels"), _get_default_gpuarray__cache_path,
apply=filter_base_compiledir, apply=_filter_base_compiledir,
mutable=False, mutable=False,
), ),
in_c_key=False, in_c_key=False,
......
...@@ -23,13 +23,13 @@ import numpy.distutils ...@@ -23,13 +23,13 @@ import numpy.distutils
import theano import theano
from theano import config from theano import config
from theano.configdefaults import gcc_version_str, local_bitwidth from theano.configdefaults import gcc_version_str
# we will abuse the lockfile mechanism when reading and writing the registry # we will abuse the lockfile mechanism when reading and writing the registry
from theano.gof import compilelock from theano.gof import compilelock
from theano.gof.utils import flatten, hash_from_code from theano.gof.utils import flatten, hash_from_code
from theano.link.c.exceptions import MissingGXX from theano.link.c.exceptions import MissingGXX
from theano.utils import output_subprocess_Popen, subprocess_Popen from theano.utils import LOCAL_BITWIDTH, output_subprocess_Popen, subprocess_Popen
importlib = None importlib = None
...@@ -2308,7 +2308,7 @@ class GCC_compiler(Compiler): ...@@ -2308,7 +2308,7 @@ class GCC_compiler(Compiler):
if not any(["arm" in flag for flag in cxxflags]) and not any( if not any(["arm" in flag for flag in cxxflags]) and not any(
arch in platform.machine() for arch in ["arm", "aarch"] arch in platform.machine() for arch in ["arm", "aarch"]
): ):
n_bits = local_bitwidth() n_bits = LOCAL_BITWIDTH
cxxflags.append(f"-m{int(n_bits)}") cxxflags.append(f"-m{int(n_bits)}")
_logger.debug(f"Compiling for {n_bits} bit architecture") _logger.debug(f"Compiling for {n_bits} bit architecture")
...@@ -2317,7 +2317,7 @@ class GCC_compiler(Compiler): ...@@ -2317,7 +2317,7 @@ class GCC_compiler(Compiler):
# '-fPIC ignored for target (all code is position independent)' # '-fPIC ignored for target (all code is position independent)'
cxxflags.append("-fPIC") cxxflags.append("-fPIC")
if sys.platform == "win32" and local_bitwidth() == 64: if sys.platform == "win32" and LOCAL_BITWIDTH == 64:
# Under 64-bit Windows installation, sys.platform is 'win32'. # Under 64-bit Windows installation, sys.platform is 'win32'.
# We need to define MS_WIN64 for the preprocessor to be able to # We need to define MS_WIN64 for the preprocessor to be able to
# link with libpython. # link with libpython.
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import inspect import inspect
import os import os
import struct
import subprocess import subprocess
import sys import sys
import traceback import traceback
...@@ -21,12 +22,35 @@ __all__ = [ ...@@ -21,12 +22,35 @@ __all__ = [
"subprocess_Popen", "subprocess_Popen",
"call_subprocess_Popen", "call_subprocess_Popen",
"output_subprocess_Popen", "output_subprocess_Popen",
"LOCAL_BITWIDTH",
"PYTHON_INT_BITWIDTH",
] ]
__excepthooks = [] __excepthooks = []
LOCAL_BITWIDTH = struct.calcsize("P") * 8
"""
32 for 32bit arch, 64 for 64bit arch.
By "architecture", we mean the size of memory pointers (size_t in C),
*not* the size of long int, as it can be different.
Note that according to Python documentation, `platform.architecture()` is
not reliable on OS X with universal binaries.
Also, sys.maxsize does not exist in Python < 2.6.
'P' denotes a void*, and the size is expressed in bytes.
"""
PYTHON_INT_BITWIDTH = struct.calcsize("l") * 8
"""
The bit width of Python int (C long int).
Note that it can be different from the size of a memory pointer.
'l' denotes a C long int, and the size is expressed in bytes.
"""
def __call_excepthooks(type, value, trace): def __call_excepthooks(type, value, trace):
""" """
This function is meant to replace excepthook and do some This function is meant to replace excepthook and do some
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论