提交 1a369c4d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2321 from abergeron/fix_cache_access

Change the ModuleCache to only take the lock for actual compilation.
......@@ -1281,30 +1281,18 @@ class CLinker(link.Linker):
return ((), sig)
return version, sig
def compile_cmodule(self, location=None):
    """
    Compile the module and return it.

    Drives `compile_cmodule_by_step` to completion and returns the
    final value it yields (the compiled module).

    :param location: Optional directory in which to compile the
        module; forwarded unchanged to `compile_cmodule_by_step`.
    :returns: The compiled module (last value yielded by
        `compile_cmodule_by_step`).
    """
    # Go through all steps of the compilation process.
    for step_result in self.compile_cmodule_by_step(location=location):
        pass
    # And return the output of the last step, which should be the module
    # itself.
    return step_result
def get_src_code(self):
    """Return the C source code of the dynamic module for this linker."""
    return self.get_dynamic_module().code()
def compile_cmodule_by_step(self, location=None):
def compile_cmodule(self, location=None):
"""
This method is a callback for `ModuleCache.module_from_key`.
It is a generator (thus the 'by step'), so that:
- it first yields the module's C code
- it last yields the module itself
- it may yield other intermediate outputs in-between if needed
in the future (but this is not currently the case)
This compiles the source code for this linker and returns a
loaded module.
"""
if location is None:
location = cmodule.dlimport_workdir(config.compiledir)
mod = self.build_dynamic_module()
mod = self.get_dynamic_module()
c_compiler = self.c_compiler()
libs = self.libraries()
preargs = self.compile_args()
......@@ -1323,14 +1311,12 @@ class CLinker(link.Linker):
if 'amdlibm' in libs:
libs.remove('amdlibm')
src_code = mod.code()
yield src_code
get_lock()
try:
_logger.debug("LOCATION %s", str(location))
try:
module = c_compiler.compile_str(
module_name=mod.code_hash,
src_code=src_code,
src_code=mod.code(),
location=location,
include_dirs=self.header_dirs(),
lib_dirs=self.lib_dirs(),
......@@ -1341,13 +1327,16 @@ class CLinker(link.Linker):
raise
finally:
release_lock()
return module
yield module
def build_dynamic_module(self):
def get_dynamic_module(self):
"""Return a cmodule.DynamicModule instance full of the code
for our fgraph.
This method is cached on the first call so it can be called
multiple times without penalty.
"""
if not hasattr(self, '_mod'):
self.code_gen()
mod = cmodule.DynamicModule()
......@@ -1395,8 +1384,8 @@ class CLinker(link.Linker):
mod.add_include(header)
for init_code_block in self.init_code() + self.c_init_code_apply:
mod.add_init_code(init_code_block)
return mod
self._mod = mod
return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False):
......@@ -1420,7 +1409,7 @@ class CLinker(link.Linker):
module = self.compile_cmodule()
else:
module = get_module_cache().module_from_key(
key=key, fn=self.compile_cmodule_by_step, keep_lock=keep_lock)
key=key, lnk=self, keep_lock=keep_lock)
vars = self.inputs + self.outputs + self.orphans
# List of indices that should be ignored when passing the arguments
......
......@@ -617,7 +617,20 @@ class ModuleCache(object):
Older modules will be deleted in ``clear_old``.
"""
def refresh(self, age_thresh_use=None, delete_if_problem=False):
def _get_module(self, name):
    """
    Fetch a compiled module from the loaded cache or the disk.
    """
    # EAFP: attempt the cache lookup first, fall back to importing
    # the shared object from disk on a miss.
    try:
        module = self.module_from_name[name]
    except KeyError:
        _logger.debug('loading name %s', name)
        module = self.module_from_name[name] = dlimport(name)
        self.stats[1] += 1
    else:
        _logger.debug('returning compiled module from cache %s', name)
        self.stats[0] += 1
    return module
def refresh(self, age_thresh_use=None, delete_if_problem=False, cleanup=True):
"""Update cache data by walking the cache directory structure.
Load key.pkl files that have not been loaded yet.
......@@ -627,12 +640,15 @@ class ModuleCache(object):
:param age_thresh_use: Do not use modules older than this.
Defaults to self.age_thresh_use.
:param delete_if_problem: If True, cache entries that meet one of those
two conditions are deleted:
- Those for which unpickling the KeyData file fails with an
unknown exception.
:param delete_if_problem: If True, cache entries that meet one
of those two conditions are deleted:
- Those for which unpickling the KeyData file fails with
an unknown exception.
- Duplicated modules, regardless of their age.
:param cleanup: Do a cleanup of the cache removing expired and
broken modules.
:returns: a list of modules of age higher than age_thresh_use.
"""
if age_thresh_use is None:
......@@ -640,8 +656,11 @@ class ModuleCache(object):
start_time = time.time()
too_old_to_use = []
compilelock.get_lock()
try:
to_delete = []
def rmtree(*args, **kwargs):
if cleanup:
to_delete.append((args, kwargs))
# add entries that are not in the entry_from_key dictionary
time_now = time.time()
# Go through directories in alphabetical order to ensure consistent
......@@ -655,9 +674,10 @@ class ModuleCache(object):
if not os.path.isdir(root):
continue
files = os.listdir(root)
if 'delete.me' in files or not files:
_rmtree(root, ignore_nocleanup=True,
if not files or 'delete.me' in files:
rmtree(root, ignore_nocleanup=True,
msg="delete.me found in dir")
continue
elif 'key.pkl' in files:
try:
entry = module_name_from_dir(root, files=files)
......@@ -669,7 +689,7 @@ class ModuleCache(object):
_logger.warning("ModuleCache.refresh() Found key "
"without dll in cache, deleting it. %s",
key_pkl)
_rmtree(root, ignore_nocleanup=True,
rmtree(root, ignore_nocleanup=True,
msg="missing module file", level=logging.INFO)
continue
if (time_now - last_access_time(entry)) < age_thresh_use:
......@@ -686,7 +706,7 @@ class ModuleCache(object):
# Happened once... not sure why (would be worth
# investigating if it ever happens again).
unpickle_failure()
_rmtree(root, ignore_nocleanup=True,
rmtree(root, ignore_nocleanup=True,
msg='broken cache directory [EOF]',
level=logging.WARNING)
continue
......@@ -699,7 +719,7 @@ class ModuleCache(object):
except Exception:
unpickle_failure()
if delete_if_problem:
_rmtree(root, ignore_nocleanup=True,
rmtree(root, ignore_nocleanup=True,
msg='broken cache directory',
level=logging.INFO)
else:
......@@ -708,7 +728,7 @@ class ModuleCache(object):
# not yet been imported (e.g. when running two
# different Theano-based scripts). They are not
# necessarily broken, but we cannot load them
# here.
# now. They will be loaded later if needed.
pass
continue
......@@ -719,7 +739,7 @@ class ModuleCache(object):
# do not know the config options that were used.
# As a result, we delete it instead (which is also
# simpler to implement).
_rmtree(root, ignore_nocleanup=True,
rmtree(root, ignore_nocleanup=True,
msg=(
'invalid cache entry format -- this '
'should not happen unless your cache '
......@@ -742,7 +762,7 @@ class ModuleCache(object):
key_data.key_pkl = key_pkl
else:
# This is suspicious. Better get rid of it.
_rmtree(root, ignore_nocleanup=True,
rmtree(root, ignore_nocleanup=True,
msg='module file path mismatch',
level=logging.INFO)
continue
......@@ -761,7 +781,7 @@ class ModuleCache(object):
'Found a mix of unversioned and '
'versioned keys for the same '
'module %s', key_pkl)
_rmtree(root, ignore_nocleanup=True,
rmtree(root, ignore_nocleanup=True,
msg="unversioned key(s) in cache",
level=logging.INFO)
continue
......@@ -776,9 +796,10 @@ class ModuleCache(object):
# Note that it is important to walk through
# directories in alphabetical order so as to make
# sure all new processes only use the first one.
if cleanup:
age = time.time() - last_access_time(entry)
if delete_if_problem or age > self.age_thresh_del:
_rmtree(root, ignore_nocleanup=True,
rmtree(root, ignore_nocleanup=True,
msg='duplicated module',
level=logging.DEBUG)
else:
......@@ -869,156 +890,86 @@ class ModuleCache(object):
pkl_file_to_remove)
self.loaded_key_pkl.remove(pkl_file_to_remove)
finally:
compilelock.release_lock()
if to_delete:
with compilelock.lock_ctx():
for a, kw in to_delete:
_rmtree(*a, **kw)
_logger.debug('Time needed to refresh cache: %s',
(time.time() - start_time))
return too_old_to_use
def module_from_key(self, key, fn=None, keep_lock=False, key_data=None):
def _get_from_key(self, key, key_data=None):
"""
:param fn: A callable object that will return an iterable object when
called, such that the first element in this iterable object is the
source code of the module, and the last element is the module itself.
`fn` is called only if the key is not already in the cache, with
a single keyword argument `location` that is the path to the directory
where the module should be compiled.
:param key_data: If not None, it should be a KeyData object and the
key parameter should be None. In this case, we use the info from the
KeyData object to recover the module, rather than the key itself. Note
that this implies the module already exists (and may or may not have
already been loaded).
Returns a module if the passed-in key is found in the cache
and None otherwise.
May raise ValueError if the key is malformed.
"""
# We should only use one of the two ways to get a module.
assert key_data is None or key is None
rval = None
name = None
if key is not None:
assert key_data is None
try:
_version, _rest = key
except (TypeError, ValueError):
raise ValueError(
"Invalid key. key must have form (version, rest)", key)
name = None
if key is not None and key in self.entry_from_key:
# We have seen this key either in this process or previously.
if key in self.entry_from_key:
name = self.entry_from_key[key]
elif key_data is not None:
name = key_data.get_entry()
if name is not None:
# This is an existing module we can recover.
if name not in self.module_from_name:
_logger.debug('loading name %s', name)
self.module_from_name[name] = dlimport(name)
self.stats[1] += 1
else:
self.stats[0] += 1
_logger.debug('returning compiled module from cache %s', name)
rval = self.module_from_name[name]
else:
hash_key = hash(key)
key_data = None
# We have never seen this key before.
# We acquire the lock later only if we were able to
# generate C code. Otherwise, we would take the lock for ops
# that have only a perform().
lock_taken = False
# This try/finally block ensures that the lock is released once we
# are done writing in the cache file or after raising an exception.
try:
# Embedding two try statements for Python 2.4 compatibility
# (cannot do try / except / finally).
try:
location = dlimport_workdir(self.dirname)
except OSError, e:
_logger.error(e)
if e.errno == 31:
_logger.error('There are %i files in %s',
len(os.listdir(config.compiledir)),
config.compiledir)
raise
try:
compile_steps = fn(location=location).__iter__()
# Check if we already know a module with the same hash.
# If we do, then there is no need to even compile it.
duplicated_module = False
# The first compilation step is to yield the source code.
src_code = next(compile_steps)
module_hash = get_module_hash(src_code, key)
# The op has c_code, so take the lock.
compilelock.get_lock()
lock_taken = True
if not os.path.exists(location):
# Temporary fix, we should make sure it don't
# get deleted by the clear*() fct.
os.makedirs(location)
assert key_data is not None
name = key_data.get_entry()
if name is None:
return None
return self._get_module(name)
def _get_from_hash(self, module_hash, key, keep_lock=False):
if module_hash in self.module_hash_to_key_data:
_logger.debug("Duplicated module! Will re-use the "
"previous one")
duplicated_module = True
# Load the already existing module.
key_data = self.module_hash_to_key_data[module_hash]
# Note that we do not pass the `fn` argument, since it
# should not be used considering that the module should
# already be compiled.
module = self.module_from_key(key=None,
key_data=key_data)
name = module.__file__
# Add current key to the set of keys associated to the
# same module. We only save the KeyData object of
# versioned modules.
module = self._get_from_key(None, key_data)
with compilelock.lock_ctx(keep_lock=keep_lock):
try:
key_data.add_key(key, save_pkl=bool(_version))
key_data.add_key(key, save_pkl=bool(key[0]))
key_broken = False
except cPickle.PicklingError:
# This should only happen if we tried to save the
# pickled file.
assert _version
# The key we are trying to add is broken: we will
# not add it after all.
key_data.remove_key(key)
key_broken = True
if (_version and not key_broken and
if (key[0] and not key_broken and
self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl)
self._update_mappings(key, key_data, module.__file__)
return module
else:
return None
# We can delete the work directory.
_rmtree(location, ignore_nocleanup=True,
msg='temporary workdir of duplicated module')
def _update_mappings(self, key, key_data, name):
all_keys = key_data.keys
if not all_keys:
all_keys = [key]
assert key in all_keys
for k in all_keys:
if k in self.entry_from_key:
assert self.entry_from_key[k] == name
else:
# Will fail if there is an error compiling the C code.
# The exception will be caught and the work dir will be
# deleted.
while True:
try:
# The module should be returned by the last
# step of the compilation.
module = next(compile_steps)
except StopIteration:
break
self.entry_from_key[k] = name
if key[0]:
self.similar_keys.setdefault(get_safe_part(k),
[]).append(key)
# Obtain path to the '.so' module file.
def _add_to_cache(self, module, key, module_hash):
"""
This function expects the compile lock to be held.
"""
name = module.__file__
_logger.debug("Adding module to cache %s %s",
key, name)
assert name.startswith(location)
assert name not in self.module_from_name
# Changing the hash of the key is not allowed during
# compilation. That is the only cause found that makes
# the following assert fail.
assert hash(key) == hash_key
assert key not in self.entry_from_key
location = os.path.dirname(name)
key_pkl = os.path.join(location, 'key.pkl')
assert not os.path.exists(key_pkl)
key_data = KeyData(
......@@ -1027,93 +978,103 @@ class ModuleCache(object):
key_pkl=key_pkl,
entry=name)
# Note that we only save KeyData objects associated to
# versioned modules. So for unversioned key, the
# `key_pkl` field of the KeyData object will be a
# non-existing file (which does not matter since it
# will not be accessed).
if _version:
if key[0]:
try:
key_data.save_pkl()
key_broken = False
except cPickle.PicklingError:
key_broken = True
# Remove key from the KeyData object, to make
# sure we never try to save it again.
# We still keep the KeyData object and save it
# so that the module can be re-used in the
# future.
key_data.keys = set()
key_data.remove_key(key)
key_data.save_pkl()
if not key_broken and self.check_for_broken_eq:
self.check_key(key, key_pkl)
# Adding the KeyData file to this set means it is a
# versioned module.
self.loaded_key_pkl.add(key_pkl)
elif config.cmodule.warn_no_version:
key_flat = flatten(key)
ops = [k for k in key_flat
if isinstance(k, theano.Op)]
ops = [k for k in key_flat if isinstance(k, theano.Op)]
_logger.warning("not all the"
" following op(s) implement"
" c_code_cache_version(). This makes them"
" recompiled for each process." + str(ops))
self._update_mappings(key, key_data, module.__file__)
return key_data
# Map the new module to its KeyData object. Note that
# we need to do it regardless of whether the key is
# versioned or not if we want to be able to re-use this
# module inside the same process.
self.module_hash_to_key_data[module_hash] = key_data
def module_from_key(self, key, lnk=None, keep_lock=False):
"""
Return a module from the cache, compiling it if necessary.
except Exception:
# This may happen e.g. when an Op has no C implementation.
# In any case, we do not want to keep around the temporary
# work directory, as it may cause trouble if we create too
# many of these. The 'ignore_if_missing' flag is set just
# in case this directory would have already been deleted.
_rmtree(location, ignore_if_missing=True,
msg=('exception -- '
'typically means no C implementation'))
raise
:param key: The key object associated with the module. If this
hits a match, we avoid compilation.
finally:
# Release lock if needed.
if not keep_lock and lock_taken:
compilelock.release_lock()
:param lnk: Usually a CLinker instance, but it can be any
object that defines the `get_src_code()` and
`compile_cmodule(location)` functions. The first
one returns the source code of the module to
load/compile and the second performs the actual
compilation.
# Update map from key to module name for all keys associated to
# this same module.
all_keys = key_data.keys
if not all_keys:
# Should only happen for broken keys.
assert key_broken
all_keys = [key]
else:
assert key in key_data.keys
for k in all_keys:
if k in self.entry_from_key:
# If we had already seen this key, then it should be
# associated to the same module.
assert self.entry_from_key[k] == name
else:
self.entry_from_key[k] = name
if _version:
self.similar_keys.setdefault(get_safe_part(k),
[]).append(key)
:param keep_lock: If True, the compilation lock will not be
released if taken.
"""
# Is the module in the cache?
module = self._get_from_key(key)
if module is not None:
return module
if name in self.module_from_name:
# May happen if we are re-using an existing module.
assert duplicated_module
assert self.module_from_name[name] is module
else:
lock_taken = False
src_code = lnk.get_src_code()
# Is the source code already in the cache?
module_hash = get_module_hash(src_code, key)
module = self._get_from_hash(module_hash, key, keep_lock=keep_lock)
if module is not None:
return module
with compilelock.lock_ctx(keep_lock=keep_lock):
# Maybe somebody else compiled it for us while we
# were waiting for the lock. Try to load it again.
self.refresh(cleanup=False)
module = self._get_from_key(key)
if module is not None:
return module
module = self._get_from_hash(module_hash, key)
if module is not None:
return module
hash_key = hash(key)
nocleanup = False
try:
location = dlimport_workdir(self.dirname)
module = lnk.compile_cmodule(location)
name = module.__file__
assert name.startswith(location)
assert name not in self.module_from_name
self.module_from_name[name] = module
nocleanup = True
except OSError, e:
_logger.error(e)
if e.errno == 31:
_logger.error('There are %i files in %s',
len(os.listdir(config.compiledir)),
config.compiledir)
raise
finally:
if not nocleanup:
_rmtree(location, ignore_if_missing=True,
msg='exception during compilation')
# Changing the hash of the key is not allowed during
# compilation.
assert hash(key) == hash_key
key_data = self._add_to_cache(module, key, module_hash)
self.module_hash_to_key_data[module_hash] = key_data
self.stats[2] += 1
rval = module
#_logger.debug('stats %s %i', self.stats, sum(self.stats))
return rval
return module
def check_key(self, key, key_pkl):
"""
......@@ -1193,8 +1154,7 @@ class ModuleCache(object):
else:
age_thresh_use = None
compilelock.get_lock()
try:
with compilelock.lock_ctx():
# Update the age of modules that have been accessed by other
# processes and get all module that are too old to use
# (not loaded in self.entry_from_key).
......@@ -1213,9 +1173,6 @@ class ModuleCache(object):
_rmtree(parent, msg='old cache directory', level=logging.INFO,
ignore_nocleanup=True)
finally:
compilelock.release_lock()
def clear(self, unversioned_min_age=None, clear_base_files=False,
delete_if_problem=False):
"""
......@@ -1232,16 +1189,13 @@ class ModuleCache(object):
:param delete_if_problem: See help of refresh() method.
"""
compilelock.get_lock()
try:
with compilelock.lock_ctx():
self.clear_old(
age_thresh_del=-1.0,
delete_if_problem=delete_if_problem)
self.clear_unversioned(min_age=unversioned_min_age)
if clear_base_files:
self.clear_base_files()
finally:
compilelock.release_lock()
def clear_base_files(self):
"""
......@@ -1253,8 +1207,7 @@ class ModuleCache(object):
rename them with the '.delete.me' extension, to mark them to be deleted
next time we clear the cache.
"""
compilelock.get_lock()
try:
with compilelock.lock_ctx():
for base_dir in ('cuda_ndarray', 'cutils_ext', 'lazylinker_ext',
'scan_perform'):
to_delete = os.path.join(self.dirname, base_dir + '.delete.me')
......@@ -1272,8 +1225,6 @@ class ModuleCache(object):
except Exception:
_logger.warning('Could not move %s to %s',
to_rename, to_delete)
finally:
compilelock.release_lock()
def clear_unversioned(self, min_age=None):
"""
......@@ -1288,9 +1239,8 @@ class ModuleCache(object):
if min_age is None:
min_age = self.age_thresh_del_unversioned
compilelock.get_lock()
with compilelock.lock_ctx():
all_key_datas = self.module_hash_to_key_data.values()
try:
for key_data in all_key_datas:
if not key_data.keys:
# May happen for broken versioned keys.
......@@ -1363,17 +1313,12 @@ class ModuleCache(object):
_rmtree(os.path.join(self.dirname, filename),
msg='old unversioned', level=logging.INFO,
ignore_nocleanup=True)
finally:
compilelock.release_lock()
def _on_atexit(self):
# Note: no need to call refresh() since it is called by clear_old().
compilelock.get_lock()
try:
with compilelock.lock_ctx():
self.clear_old()
self.clear_unversioned()
finally:
compilelock.release_lock()
_logger.debug('Time spent checking keys: %s',
self.time_spent_in_check_key)
......
......@@ -8,6 +8,8 @@ import socket # only used for gethostname()
import time
import logging
from contextlib import contextmanager
from theano import config
from theano.configparser import AddConfigVar, IntParam
......@@ -44,6 +46,14 @@ def force_unlock():
release_lock()
@contextmanager
def lock_ctx(lock_dir=None, keep_lock=False, **kw):
    """
    Context manager that holds the compilation lock for the duration
    of the `with` block.

    :param lock_dir: Forwarded to `get_lock`.
    :param keep_lock: If True, the lock is deliberately NOT released
        on exit, so the caller keeps holding it after the block.
    :param kw: Extra keyword arguments forwarded to `get_lock`.
    """
    get_lock(lock_dir=lock_dir, **kw)
    try:
        yield
    finally:
        # Bug fix: the original released the lock only on a normal
        # exit, so an exception raised inside the `with` body leaked
        # the lock and would deadlock any later acquirer. Cleanup in
        # a @contextmanager must live in a try/finally around yield.
        if not keep_lock:
            release_lock()
def get_lock(lock_dir=None, **kw):
"""
Obtain lock on compilation directory.
......
File mode changed from 100755 to 100644
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论