Commit 1a369c4d, authored by Frédéric Bastien

Merge pull request #2321 from abergeron/fix_cache_access

Change the ModuleCache to only take the lock for actual compilation.
...@@ -1281,30 +1281,18 @@ class CLinker(link.Linker): ...@@ -1281,30 +1281,18 @@ class CLinker(link.Linker):
return ((), sig) return ((), sig)
return version, sig return version, sig
def get_src_code(self):
mod = self.get_dynamic_module()
return mod.code()
def compile_cmodule(self, location=None): def compile_cmodule(self, location=None):
""" """
Compile the module and return it. This compiles the source code for this linker and returns a
""" loaded module.
# Go through all steps of the compilation process.
for step_result in self.compile_cmodule_by_step(location=location):
pass
# And return the output of the last step, which should be the module
# itself.
return step_result
def compile_cmodule_by_step(self, location=None):
"""
This method is a callback for `ModuleCache.module_from_key`.
It is a generator (thus the 'by step'), so that:
- it first yields the module's C code
- it last yields the module itself
- it may yield other intermediate outputs in-between if needed
in the future (but this is not currently the case)
""" """
if location is None: if location is None:
location = cmodule.dlimport_workdir(config.compiledir) location = cmodule.dlimport_workdir(config.compiledir)
mod = self.build_dynamic_module() mod = self.get_dynamic_module()
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
libs = self.libraries() libs = self.libraries()
preargs = self.compile_args() preargs = self.compile_args()
...@@ -1323,48 +1311,49 @@ class CLinker(link.Linker): ...@@ -1323,48 +1311,49 @@ class CLinker(link.Linker):
if 'amdlibm' in libs: if 'amdlibm' in libs:
libs.remove('amdlibm') libs.remove('amdlibm')
src_code = mod.code() src_code = mod.code()
yield src_code
get_lock() get_lock()
try: try:
_logger.debug("LOCATION %s", str(location)) _logger.debug("LOCATION %s", str(location))
try: module = c_compiler.compile_str(
module = c_compiler.compile_str( module_name=mod.code_hash,
module_name=mod.code_hash, src_code=mod.code(),
src_code=src_code, location=location,
location=location, include_dirs=self.header_dirs(),
include_dirs=self.header_dirs(), lib_dirs=self.lib_dirs(),
lib_dirs=self.lib_dirs(), libs=libs,
libs=libs, preargs=preargs)
preargs=preargs) except Exception, e:
except Exception, e: e.args += (str(self.fgraph),)
e.args += (str(self.fgraph),) raise
raise
finally: finally:
release_lock() release_lock()
return module
yield module def get_dynamic_module(self):
def build_dynamic_module(self):
"""Return a cmodule.DynamicModule instance full of the code """Return a cmodule.DynamicModule instance full of the code
for our fgraph. for our fgraph.
This method is cached on the first call so it can be called
multiple times without penalty.
""" """
self.code_gen() if not hasattr(self, '_mod'):
self.code_gen()
mod = cmodule.DynamicModule() mod = cmodule.DynamicModule()
# The code of instantiate # The code of instantiate
# the 1 is for error_storage # the 1 is for error_storage
code = self.instantiate_code(1 + len(self.args)) code = self.instantiate_code(1 + len(self.args))
instantiate = cmodule.ExtFunction('instantiate', code, instantiate = cmodule.ExtFunction('instantiate', code,
method=cmodule.METH_VARARGS) method=cmodule.METH_VARARGS)
#['error_storage'] + argnames, #['error_storage'] + argnames,
#local_dict = d, #local_dict = d,
#global_dict = {}) #global_dict = {})
# Static methods that can run and destroy the struct built by # Static methods that can run and destroy the struct built by
# instantiate. # instantiate.
if PY3: if PY3:
static = """ static = """
static int {struct_name}_executor({struct_name} *self) {{ static int {struct_name}_executor({struct_name} *self) {{
return self->run(); return self->run();
}} }}
...@@ -1374,8 +1363,8 @@ class CLinker(link.Linker): ...@@ -1374,8 +1363,8 @@ class CLinker(link.Linker):
delete self; delete self;
}} }}
""".format(struct_name=self.struct_name) """.format(struct_name=self.struct_name)
else: else:
static = """ static = """
static int %(struct_name)s_executor(%(struct_name)s* self) { static int %(struct_name)s_executor(%(struct_name)s* self) {
return self->run(); return self->run();
} }
...@@ -1386,17 +1375,17 @@ class CLinker(link.Linker): ...@@ -1386,17 +1375,17 @@ class CLinker(link.Linker):
""" % dict(struct_name=self.struct_name) """ % dict(struct_name=self.struct_name)
# We add all the support code, compile args, headers and libs we need. # We add all the support code, compile args, headers and libs we need.
for support_code in self.support_code() + self.c_support_code_apply: for support_code in self.support_code() + self.c_support_code_apply:
mod.add_support_code(support_code) mod.add_support_code(support_code)
mod.add_support_code(self.struct_code) mod.add_support_code(self.struct_code)
mod.add_support_code(static) mod.add_support_code(static)
mod.add_function(instantiate) mod.add_function(instantiate)
for header in self.headers(): for header in self.headers():
mod.add_include(header) mod.add_include(header)
for init_code_block in self.init_code() + self.c_init_code_apply: for init_code_block in self.init_code() + self.c_init_code_apply:
mod.add_init_code(init_code_block) mod.add_init_code(init_code_block)
self._mod = mod
return mod return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage, def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False): keep_lock=False):
...@@ -1420,7 +1409,7 @@ class CLinker(link.Linker): ...@@ -1420,7 +1409,7 @@ class CLinker(link.Linker):
module = self.compile_cmodule() module = self.compile_cmodule()
else: else:
module = get_module_cache().module_from_key( module = get_module_cache().module_from_key(
key=key, fn=self.compile_cmodule_by_step, keep_lock=keep_lock) key=key, lnk=self, keep_lock=keep_lock)
vars = self.inputs + self.outputs + self.orphans vars = self.inputs + self.outputs + self.orphans
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
......
...@@ -617,7 +617,20 @@ class ModuleCache(object): ...@@ -617,7 +617,20 @@ class ModuleCache(object):
Older modules will be deleted in ``clear_old``. Older modules will be deleted in ``clear_old``.
""" """
def refresh(self, age_thresh_use=None, delete_if_problem=False): def _get_module(self, name):
"""
Fetch a compiled module from the loaded cache or the disk.
"""
if name not in self.module_from_name:
_logger.debug('loading name %s', name)
self.module_from_name[name] = dlimport(name)
self.stats[1] += 1
else:
_logger.debug('returning compiled module from cache %s', name)
self.stats[0] += 1
return self.module_from_name[name]
def refresh(self, age_thresh_use=None, delete_if_problem=False, cleanup=True):
"""Update cache data by walking the cache directory structure. """Update cache data by walking the cache directory structure.
Load key.pkl files that have not been loaded yet. Load key.pkl files that have not been loaded yet.
...@@ -627,12 +640,15 @@ class ModuleCache(object): ...@@ -627,12 +640,15 @@ class ModuleCache(object):
:param age_thresh_use: Do not use modules olther than this. :param age_thresh_use: Do not use modules olther than this.
Defaults to self.age_thresh_use. Defaults to self.age_thresh_use.
:param delete_if_problem: If True, cache entries that meet one of those :param delete_if_problem: If True, cache entries that meet one
two conditions are deleted: of those two conditions are deleted:
- Those for which unpickling the KeyData file fails with an - Those for which unpickling the KeyData file fails with
unknown exception. an unknown exception.
- Duplicated modules, regardless of their age. - Duplicated modules, regardless of their age.
:param cleanup: Do a cleanup of the cache removing expired and
broken modules.
:returns: a list of modules of age higher than age_thresh_use. :returns: a list of modules of age higher than age_thresh_use.
""" """
if age_thresh_use is None: if age_thresh_use is None:
...@@ -640,480 +656,425 @@ class ModuleCache(object): ...@@ -640,480 +656,425 @@ class ModuleCache(object):
start_time = time.time() start_time = time.time()
too_old_to_use = [] too_old_to_use = []
compilelock.get_lock() to_delete = []
try: def rmtree(*args, **kwargs):
# add entries that are not in the entry_from_key dictionary if cleanup:
time_now = time.time() to_delete.append((args, kwargs))
# Go through directories in alphabetical order to ensure consistent
# behavior. # add entries that are not in the entry_from_key dictionary
subdirs = sorted(os.listdir(self.dirname)) time_now = time.time()
for root in subdirs: # Go through directories in alphabetical order to ensure consistent
root = os.path.join(self.dirname, root) # behavior.
key_pkl = os.path.join(root, 'key.pkl') subdirs = sorted(os.listdir(self.dirname))
if key_pkl in self.loaded_key_pkl: for root in subdirs:
continue root = os.path.join(self.dirname, root)
if not os.path.isdir(root): key_pkl = os.path.join(root, 'key.pkl')
if key_pkl in self.loaded_key_pkl:
continue
if not os.path.isdir(root):
continue
files = os.listdir(root)
if not files or 'delete.me' in files:
rmtree(root, ignore_nocleanup=True,
msg="delete.me found in dir")
continue
elif 'key.pkl' in files:
try:
entry = module_name_from_dir(root, files=files)
except ValueError: # there is a key but no dll!
if not root.startswith("/tmp"):
# Under /tmp, file are removed periodically by the
# os. So it is normal that this happens from time
# to time.
_logger.warning("ModuleCache.refresh() Found key "
"without dll in cache, deleting it. %s",
key_pkl)
rmtree(root, ignore_nocleanup=True,
msg="missing module file", level=logging.INFO)
continue continue
files = os.listdir(root) if (time_now - last_access_time(entry)) < age_thresh_use:
if 'delete.me' in files or not files: _logger.debug('refresh adding %s', key_pkl)
_rmtree(root, ignore_nocleanup=True,
msg="delete.me found in dir") def unpickle_failure():
elif 'key.pkl' in files: _logger.info("ModuleCache.refresh() Failed to "
"unpickle cache file %s", key_pkl)
try: try:
entry = module_name_from_dir(root, files=files) with open(key_pkl, 'rb') as f:
except ValueError: # there is a key but no dll! key_data = cPickle.load(f)
if not root.startswith("/tmp"): except EOFError:
# Under /tmp, file are removed periodically by the # Happened once... not sure why (would be worth
# os. So it is normal that this happens from time # investigating if it ever happens again).
# to time. unpickle_failure()
_logger.warning("ModuleCache.refresh() Found key " rmtree(root, ignore_nocleanup=True,
"without dll in cache, deleting it. %s", msg='broken cache directory [EOF]',
key_pkl) level=logging.WARNING)
_rmtree(root, ignore_nocleanup=True, continue
msg="missing module file", level=logging.INFO) except ValueError:
# This can happen when we have bad config value
# in the cuda.nvcc_compiler.py file.
# We should not hide it here, as this will cause
# an unrelated error to appear.
raise
except Exception:
unpickle_failure()
if delete_if_problem:
rmtree(root, ignore_nocleanup=True,
msg='broken cache directory',
level=logging.INFO)
else:
# This exception is often triggered by keys
# that contain references to classes that have
# not yet been imported (e.g. when running two
# different Theano-based scripts). They are not
# necessarily broken, but we cannot load them
# now. They will be loaded later if needed.
pass
continue continue
if (time_now - last_access_time(entry)) < age_thresh_use:
_logger.debug('refresh adding %s', key_pkl)
def unpickle_failure():
_logger.info("ModuleCache.refresh() Failed to "
"unpickle cache file %s", key_pkl)
try:
with open(key_pkl, 'rb') as f:
key_data = cPickle.load(f)
except EOFError:
# Happened once... not sure why (would be worth
# investigating if it ever happens again).
unpickle_failure()
_rmtree(root, ignore_nocleanup=True,
msg='broken cache directory [EOF]',
level=logging.WARNING)
continue
except ValueError:
# This can happen when we have bad config value
# in the cuda.nvcc_compiler.py file.
# We should not hide it here, as this will cause
# an unrelated error to appear.
raise
except Exception:
unpickle_failure()
if delete_if_problem:
_rmtree(root, ignore_nocleanup=True,
msg='broken cache directory',
level=logging.INFO)
else:
# This exception is often triggered by keys
# that contain references to classes that have
# not yet been imported (e.g. when running two
# different Theano-based scripts). They are not
# necessarily broken, but we cannot load them
# here.
pass
continue
if not isinstance(key_data, KeyData): if not isinstance(key_data, KeyData):
# This is some old cache data, that does not fit # This is some old cache data, that does not fit
# the new cache format. It would be possible to # the new cache format. It would be possible to
# update it, but it is not entirely safe since we # update it, but it is not entirely safe since we
# do not know the config options that were used. # do not know the config options that were used.
# As a result, we delete it instead (which is also # As a result, we delete it instead (which is also
# simpler to implement). # simpler to implement).
_rmtree(root, ignore_nocleanup=True, rmtree(root, ignore_nocleanup=True,
msg=( msg=(
'invalid cache entry format -- this ' 'invalid cache entry format -- this '
'should not happen unless your cache ' 'should not happen unless your cache '
'was really old'), 'was really old'),
level=logging.WARN) level=logging.WARN)
continue continue
# Check the path to the module stored in the KeyData # Check the path to the module stored in the KeyData
# object matches the path to `entry`. There may be # object matches the path to `entry`. There may be
# a mismatch e.g. due to symlinks, or some directory # a mismatch e.g. due to symlinks, or some directory
# being renamed since last time cache was created. # being renamed since last time cache was created.
kd_entry = key_data.get_entry() kd_entry = key_data.get_entry()
if kd_entry != entry: if kd_entry != entry:
if is_same_entry(entry, kd_entry): if is_same_entry(entry, kd_entry):
# Update KeyData object. Note that we also need # Update KeyData object. Note that we also need
# to update the key_pkl field, because it is # to update the key_pkl field, because it is
# likely to be incorrect if the entry itself # likely to be incorrect if the entry itself
# was wrong. # was wrong.
key_data.entry = entry key_data.entry = entry
key_data.key_pkl = key_pkl key_data.key_pkl = key_pkl
else: else:
# This is suspicious. Better get rid of it. # This is suspicious. Better get rid of it.
_rmtree(root, ignore_nocleanup=True, rmtree(root, ignore_nocleanup=True,
msg='module file path mismatch', msg='module file path mismatch',
level=logging.INFO) level=logging.INFO)
continue continue
# Find unversioned keys from other processes. # Find unversioned keys from other processes.
# TODO: check if this can happen at all # TODO: check if this can happen at all
to_del = [key for key in key_data.keys if not key[0]] to_del = [key for key in key_data.keys if not key[0]]
if to_del: if to_del:
_logger.warning(
"ModuleCache.refresh() Found unversioned "
"key in cache, removing it. %s", key_pkl)
# Since the version is in the module hash, all
# keys should be unversioned.
if len(to_del) != len(key_data.keys):
_logger.warning( _logger.warning(
"ModuleCache.refresh() Found unversioned " 'Found a mix of unversioned and '
"key in cache, removing it. %s", key_pkl) 'versioned keys for the same '
# Since the version is in the module hash, all 'module %s', key_pkl)
# keys should be unversioned. rmtree(root, ignore_nocleanup=True,
if len(to_del) != len(key_data.keys): msg="unversioned key(s) in cache",
_logger.warning( level=logging.INFO)
'Found a mix of unversioned and ' continue
'versioned keys for the same '
'module %s', key_pkl)
_rmtree(root, ignore_nocleanup=True,
msg="unversioned key(s) in cache",
level=logging.INFO)
continue
mod_hash = key_data.module_hash mod_hash = key_data.module_hash
if mod_hash in self.module_hash_to_key_data: if mod_hash in self.module_hash_to_key_data:
# This may happen when two processes running # This may happen when two processes running
# simultaneously compiled the same module, one # simultaneously compiled the same module, one
# after the other. We delete one once it is old # after the other. We delete one once it is old
# enough (to be confident there is no other process # enough (to be confident there is no other process
# using it), or if `delete_if_problem` is True. # using it), or if `delete_if_problem` is True.
# Note that it is important to walk through # Note that it is important to walk through
# directories in alphabetical order so as to make # directories in alphabetical order so as to make
# sure all new processes only use the first one. # sure all new processes only use the first one.
if cleanup:
age = time.time() - last_access_time(entry) age = time.time() - last_access_time(entry)
if delete_if_problem or age > self.age_thresh_del: if delete_if_problem or age > self.age_thresh_del:
_rmtree(root, ignore_nocleanup=True, rmtree(root, ignore_nocleanup=True,
msg='duplicated module', msg='duplicated module',
level=logging.DEBUG) level=logging.DEBUG)
else: else:
_logger.debug('Found duplicated module not ' _logger.debug('Found duplicated module not '
'old enough yet to be deleted ' 'old enough yet to be deleted '
'(age: %s): %s', '(age: %s): %s',
age, entry) age, entry)
continue continue
# Remember the map from a module's hash to the KeyData # Remember the map from a module's hash to the KeyData
# object associated with it. # object associated with it.
self.module_hash_to_key_data[mod_hash] = key_data self.module_hash_to_key_data[mod_hash] = key_data
for key in key_data.keys: for key in key_data.keys:
if key not in self.entry_from_key: if key not in self.entry_from_key:
self.entry_from_key[key] = entry self.entry_from_key[key] = entry
# Assert that we have not already got this # Assert that we have not already got this
# entry somehow. # entry somehow.
assert entry not in self.module_from_name assert entry not in self.module_from_name
# Store safe part of versioned keys. # Store safe part of versioned keys.
if key[0]: if key[0]:
self.similar_keys.setdefault( self.similar_keys.setdefault(
get_safe_part(key), get_safe_part(key),
[]).append(key) []).append(key)
else: else:
_logger.warning( _logger.warning(
"The same cache key is associated to " "The same cache key is associated to "
"different modules (%s and %s). This " "different modules (%s and %s). This "
"is not supposed to happen! You may " "is not supposed to happen! You may "
"need to manually delete your cache " "need to manually delete your cache "
"directory to fix this.", "directory to fix this.",
self.entry_from_key[key], self.entry_from_key[key],
entry)
# Clean up the name space to prevent bug.
if key_data.keys:
del key
self.loaded_key_pkl.add(key_pkl)
else:
too_old_to_use.append(entry)
# If the compilation failed, no key.pkl is in that
# directory, but a mod.* should be there.
# We do nothing here.
# Clean up the name space to prevent bug.
del root, files, subdirs
# Remove entries that are not in the filesystem.
items_copy = list(self.module_hash_to_key_data.iteritems())
for module_hash, key_data in items_copy:
entry = key_data.get_entry()
try:
# Test to see that the file is [present and] readable.
open(entry).close()
gone = False
except IOError:
gone = True
if gone:
# Assert that we did not have one of the deleted files
# loaded up and in use.
# If so, it should not have been deleted. This should be
# considered a failure of the OTHER process, that deleted
# it.
if entry in self.module_from_name:
_logger.warning("A module that was loaded by this "
"ModuleCache can no longer be read from file "
"%s... this could lead to problems.",
entry) entry)
del self.module_from_name[entry] # Clean up the name space to prevent bug.
if key_data.keys:
_logger.info("deleting ModuleCache entry %s", entry) del key
key_data.delete_keys_from(self.entry_from_key) self.loaded_key_pkl.add(key_pkl)
del self.module_hash_to_key_data[module_hash] else:
if key_data.keys and list(key_data.keys)[0][0]: too_old_to_use.append(entry)
# this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we # If the compilation failed, no key.pkl is in that
# are responding by printing a warning, removing # directory, but a mod.* should be there.
# evidence that we ever saw this mystery key. # We do nothing here.
pkl_file_to_remove = key_data.key_pkl
if not key_data.key_pkl.startswith("/tmp"): # Clean up the name space to prevent bug.
# Under /tmp, file are removed periodically by the del root, files, subdirs
# os. So it is normal that this happen from time to
# time. # Remove entries that are not in the filesystem.
_logger.warning("Removing key file %s because the " items_copy = list(self.module_hash_to_key_data.iteritems())
"corresponding module is gone from the " for module_hash, key_data in items_copy:
"file system.", entry = key_data.get_entry()
pkl_file_to_remove) try:
self.loaded_key_pkl.remove(pkl_file_to_remove) # Test to see that the file is [present and] readable.
open(entry).close()
finally: gone = False
compilelock.release_lock() except IOError:
gone = True
if gone:
# Assert that we did not have one of the deleted files
# loaded up and in use.
# If so, it should not have been deleted. This should be
# considered a failure of the OTHER process, that deleted
# it.
if entry in self.module_from_name:
_logger.warning("A module that was loaded by this "
"ModuleCache can no longer be read from file "
"%s... this could lead to problems.",
entry)
del self.module_from_name[entry]
_logger.info("deleting ModuleCache entry %s", entry)
key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash]
if key_data.keys and list(key_data.keys)[0][0]:
# this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we
# are responding by printing a warning, removing
# evidence that we ever saw this mystery key.
pkl_file_to_remove = key_data.key_pkl
if not key_data.key_pkl.startswith("/tmp"):
# Under /tmp, file are removed periodically by the
# os. So it is normal that this happen from time to
# time.
_logger.warning("Removing key file %s because the "
"corresponding module is gone from the "
"file system.",
pkl_file_to_remove)
self.loaded_key_pkl.remove(pkl_file_to_remove)
if to_delete:
with compilelock.lock_ctx():
for a, kw in to_delete:
_rmtree(*a, **kw)
_logger.debug('Time needed to refresh cache: %s', _logger.debug('Time needed to refresh cache: %s',
(time.time() - start_time)) (time.time() - start_time))
return too_old_to_use return too_old_to_use
def module_from_key(self, key, fn=None, keep_lock=False, key_data=None): def _get_from_key(self, key, key_data=None):
""" """
:param fn: A callable object that will return an iterable object when Returns a module if the passed-in key is found in the cache
called, such that the first element in this iterable object is the and None otherwise.
source code of the module, and the last element is the module itself.
`fn` is called only if the key is not already in the cache, with May raise ValueError if the key is malformed.
a single keyword argument `location` that is the path to the directory
where the module should be compiled.
:param key_data: If not None, it should be a KeyData object and the
key parameter should be None. In this case, we use the info from the
KeyData object to recover the module, rather than the key itself. Note
that this implies the module already exists (and may or may not have
already been loaded).
""" """
# We should only use one of the two ways to get a module. name = None
assert key_data is None or key is None
rval = None
if key is not None: if key is not None:
assert key_data is None
try: try:
_version, _rest = key _version, _rest = key
except (TypeError, ValueError): except (TypeError, ValueError):
raise ValueError( raise ValueError(
"Invalid key. key must have form (version, rest)", key) "Invalid key. key must have form (version, rest)", key)
name = None if key in self.entry_from_key:
if key is not None and key in self.entry_from_key: name = self.entry_from_key[key]
# We have seen this key either in this process or previously. else:
name = self.entry_from_key[key] assert key_data is not None
elif key_data is not None:
name = key_data.get_entry() name = key_data.get_entry()
if name is not None: if name is None:
# This is an existing module we can recover. return None
if name not in self.module_from_name: return self._get_module(name)
_logger.debug('loading name %s', name)
self.module_from_name[name] = dlimport(name) def _get_from_hash(self, module_hash, key, keep_lock=False):
self.stats[1] += 1 if module_hash in self.module_hash_to_key_data:
else: key_data = self.module_hash_to_key_data[module_hash]
self.stats[0] += 1 module = self._get_from_key(None, key_data)
_logger.debug('returning compiled module from cache %s', name) with compilelock.lock_ctx(keep_lock=keep_lock):
rval = self.module_from_name[name] try:
key_data.add_key(key, save_pkl=bool(key[0]))
key_broken = False
except cPickle.PicklingError:
key_data.remove_key(key)
key_broken = True
if (key[0] and not key_broken and
self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl)
self._update_mappings(key, key_data, module.__file__)
return module
else: else:
hash_key = hash(key) return None
key_data = None
# We have never seen this key before. def _update_mappings(self, key, key_data, name):
all_keys = key_data.keys
# We acquire the lock later only if we were able to if not all_keys:
# generate C code. Otherwise, we would take the lock for ops all_keys = [key]
# that have only a perform(). assert key in all_keys
lock_taken = False for k in all_keys:
# This try/finally block ensures that the lock is released once we if k in self.entry_from_key:
# are done writing in the cache file or after raising an exception. assert self.entry_from_key[k] == name
else:
self.entry_from_key[k] = name
if key[0]:
self.similar_keys.setdefault(get_safe_part(k),
[]).append(key)
def _add_to_cache(self, module, key, module_hash):
"""
This function expects the compile lock to be held.
"""
name = module.__file__
_logger.debug("Adding module to cache %s %s",
key, name)
# Changing the hash of the key is not allowed during
# compilation. That is the only cause found that makes
# the following assert fail.
assert key not in self.entry_from_key
location = os.path.dirname(name)
key_pkl = os.path.join(location, 'key.pkl')
assert not os.path.exists(key_pkl)
key_data = KeyData(
keys=set([key]),
module_hash=module_hash,
key_pkl=key_pkl,
entry=name)
if key[0]:
try: try:
# Embedding two try statements for Python 2.4 compatibility key_data.save_pkl()
# (cannot do try / except / finally). key_broken = False
try: except cPickle.PicklingError:
location = dlimport_workdir(self.dirname) key_broken = True
except OSError, e: key_data.remove_key(key)
_logger.error(e) key_data.save_pkl()
if e.errno == 31: if not key_broken and self.check_for_broken_eq:
_logger.error('There are %i files in %s', self.check_key(key, key_pkl)
len(os.listdir(config.compiledir)), self.loaded_key_pkl.add(key_pkl)
config.compiledir) elif config.cmodule.warn_no_version:
raise key_flat = flatten(key)
try: ops = [k for k in key_flat if isinstance(k, theano.Op)]
compile_steps = fn(location=location).__iter__() _logger.warning("not all the"
" following op(s) implement"
# Check if we already know a module with the same hash. " c_code_cache_version(). This makes them"
# If we do, then there is no need to even compile it. " recompiled for each process." + str(ops))
duplicated_module = False self._update_mappings(key, key_data, module.__file__)
# The first compilation step is to yield the source code. return key_data
src_code = next(compile_steps)
module_hash = get_module_hash(src_code, key) def module_from_key(self, key, lnk=None, keep_lock=False):
"""
# The op has c_code, so take the lock. Return a module from the cache, compiling it if necessary.
compilelock.get_lock()
lock_taken = True
if not os.path.exists(location):
# Temporary fix, we should make sure it don't
# get deleted by the clear*() fct.
os.makedirs(location)
if module_hash in self.module_hash_to_key_data:
_logger.debug("Duplicated module! Will re-use the "
"previous one")
duplicated_module = True
# Load the already existing module.
key_data = self.module_hash_to_key_data[module_hash]
# Note that we do not pass the `fn` argument, since it
# should not be used considering that the module should
# already be compiled.
module = self.module_from_key(key=None,
key_data=key_data)
name = module.__file__
# Add current key to the set of keys associated to the
# same module. We only save the KeyData object of
# versioned modules.
try:
key_data.add_key(key, save_pkl=bool(_version))
key_broken = False
except cPickle.PicklingError:
# This should only happen if we tried to save the
# pickled file.
assert _version
# The key we are trying to add is broken: we will
# not add it after all.
key_data.remove_key(key)
key_broken = True
if (_version and not key_broken and
self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl)
# We can delete the work directory.
_rmtree(location, ignore_nocleanup=True,
msg='temporary workdir of duplicated module')
else:
# Will fail if there is an error compiling the C code.
# The exception will be caught and the work dir will be
# deleted.
while True:
try:
# The module should be returned by the last
# step of the compilation.
module = next(compile_steps)
except StopIteration:
break
# Obtain path to the '.so' module file.
name = module.__file__
_logger.debug("Adding module to cache %s %s",
key, name)
assert name.startswith(location)
assert name not in self.module_from_name
# Changing the hash of the key is not allowed during
# compilation. That is the only cause found that makes
# the following assert fail.
assert hash(key) == hash_key
assert key not in self.entry_from_key
key_pkl = os.path.join(location, 'key.pkl')
assert not os.path.exists(key_pkl)
key_data = KeyData(
keys=set([key]),
module_hash=module_hash,
key_pkl=key_pkl,
entry=name)
# Note that we only save KeyData objects associated to
# versioned modules. So for unversioned key, the
# `key_pkl` field of the KeyData object will be a
# non-existing file (which does not matter since it
# will not be accessed).
if _version:
try:
key_data.save_pkl()
key_broken = False
except cPickle.PicklingError:
key_broken = True
# Remove key from the KeyData object, to make
# sure we never try to save it again.
# We still keep the KeyData object and save it
# so that the module can be re-used in the
# future.
key_data.keys = set()
key_data.save_pkl()
if not key_broken and self.check_for_broken_eq:
self.check_key(key, key_pkl)
# Adding the KeyData file to this set means it is a
# versioned module.
self.loaded_key_pkl.add(key_pkl)
elif config.cmodule.warn_no_version:
key_flat = flatten(key)
ops = [k for k in key_flat
if isinstance(k, theano.Op)]
_logger.warning("not all the"
" following op(s) implement"
" c_code_cache_version(). This makes them"
" recompiled for each process." + str(ops))
# Map the new module to its KeyData object. Note that
# we need to do it regardless of whether the key is
# versioned or not if we want to be able to re-use this
# module inside the same process.
self.module_hash_to_key_data[module_hash] = key_data
except Exception: :param key: The key object associated with the module. If this
# This may happen e.g. when an Op has no C implementation. hits a match, we avoid compilation.
# In any case, we do not want to keep around the temporary
# work directory, as it may cause trouble if we create too
# many of these. The 'ignore_if_missing' flag is set just
# in case this directory would have already been deleted.
_rmtree(location, ignore_if_missing=True,
msg=('exception -- '
'typically means no C implementation'))
raise
finally: :param lnk: Usually a CLinker instance, but it can be any
# Release lock if needed. object that defines the `get_src_code()` and
if not keep_lock and lock_taken: `compile_cmodule(location)` functions. The first
compilelock.release_lock() one returns the source code of the module to
load/compile and the second performs the actual
# Update map from key to module name for all keys associated to compilation.
# this same module.
all_keys = key_data.keys :param keep_lock: If True, the compilation lock will not be
if not all_keys: released if taken.
# Should only happen for broken keys. """
assert key_broken # Is the module in the cache?
all_keys = [key] module = self._get_from_key(key)
else: if module is not None:
assert key in key_data.keys return module
for k in all_keys:
if k in self.entry_from_key: lock_taken = False
# If we had already seen this key, then it should be
# associated to the same module. src_code = lnk.get_src_code()
assert self.entry_from_key[k] == name # Is the source code already in the cache?
else: module_hash = get_module_hash(src_code, key)
self.entry_from_key[k] = name module = self._get_from_hash(module_hash, key, keep_lock=keep_lock)
if _version: if module is not None:
self.similar_keys.setdefault(get_safe_part(k), return module
[]).append(key)
with compilelock.lock_ctx(keep_lock=keep_lock):
if name in self.module_from_name: # Maybe somebody else compiled it for us while we
# May happen if we are re-using an existing module. # where waiting for the lock. Try to load it again
assert duplicated_module self.refresh(cleanup=False)
assert self.module_from_name[name] is module
else: module = self._get_from_key(key)
if module is not None:
return module
module = self._get_from_hash(module_hash, key)
if module is not None:
return module
hash_key = hash(key)
nocleanup = False
try:
location = dlimport_workdir(self.dirname)
module = lnk.compile_cmodule(location)
name = module.__file__
assert name.startswith(location)
assert name not in self.module_from_name
self.module_from_name[name] = module self.module_from_name[name] = module
nocleanup = True
except OSError, e:
_logger.error(e)
if e.errno == 31:
_logger.error('There are %i files in %s',
len(os.listdir(config.compiledir)),
config.compiledir)
raise
finally:
if not nocleanup:
_rmtree(location, ignore_if_missing=True,
msg='exception during compilation')
self.stats[2] += 1 # Changing the hash of the key is not allowed during
rval = module # compilation.
#_logger.debug('stats %s %i', self.stats, sum(self.stats)) assert hash(key) == hash_key
return rval
key_data = self._add_to_cache(module, key, module_hash)
self.module_hash_to_key_data[module_hash] = key_data
self.stats[2] += 1
return module
def check_key(self, key, key_pkl): def check_key(self, key, key_pkl):
""" """
...@@ -1193,8 +1154,7 @@ class ModuleCache(object): ...@@ -1193,8 +1154,7 @@ class ModuleCache(object):
else: else:
age_thresh_use = None age_thresh_use = None
compilelock.get_lock() with compilelock.lock_ctx():
try:
# Update the age of modules that have been accessed by other # Update the age of modules that have been accessed by other
# processes and get all module that are too old to use # processes and get all module that are too old to use
# (not loaded in self.entry_from_key). # (not loaded in self.entry_from_key).
...@@ -1213,9 +1173,6 @@ class ModuleCache(object): ...@@ -1213,9 +1173,6 @@ class ModuleCache(object):
_rmtree(parent, msg='old cache directory', level=logging.INFO, _rmtree(parent, msg='old cache directory', level=logging.INFO,
ignore_nocleanup=True) ignore_nocleanup=True)
finally:
compilelock.release_lock()
def clear(self, unversioned_min_age=None, clear_base_files=False, def clear(self, unversioned_min_age=None, clear_base_files=False,
delete_if_problem=False): delete_if_problem=False):
""" """
...@@ -1232,16 +1189,13 @@ class ModuleCache(object): ...@@ -1232,16 +1189,13 @@ class ModuleCache(object):
:param delete_if_problem: See help of refresh() method. :param delete_if_problem: See help of refresh() method.
""" """
compilelock.get_lock() with compilelock.lock_ctx():
try:
self.clear_old( self.clear_old(
age_thresh_del=-1.0, age_thresh_del=-1.0,
delete_if_problem=delete_if_problem) delete_if_problem=delete_if_problem)
self.clear_unversioned(min_age=unversioned_min_age) self.clear_unversioned(min_age=unversioned_min_age)
if clear_base_files: if clear_base_files:
self.clear_base_files() self.clear_base_files()
finally:
compilelock.release_lock()
def clear_base_files(self): def clear_base_files(self):
""" """
...@@ -1253,8 +1207,7 @@ class ModuleCache(object): ...@@ -1253,8 +1207,7 @@ class ModuleCache(object):
rename them with the '.delete.me' extension, to mark them to be deleted rename them with the '.delete.me' extension, to mark them to be deleted
next time we clear the cache. next time we clear the cache.
""" """
compilelock.get_lock() with compilelock.lock_ctx():
try:
for base_dir in ('cuda_ndarray', 'cutils_ext', 'lazylinker_ext', for base_dir in ('cuda_ndarray', 'cutils_ext', 'lazylinker_ext',
'scan_perform'): 'scan_perform'):
to_delete = os.path.join(self.dirname, base_dir + '.delete.me') to_delete = os.path.join(self.dirname, base_dir + '.delete.me')
...@@ -1272,8 +1225,6 @@ class ModuleCache(object): ...@@ -1272,8 +1225,6 @@ class ModuleCache(object):
except Exception: except Exception:
_logger.warning('Could not move %s to %s', _logger.warning('Could not move %s to %s',
to_rename, to_delete) to_rename, to_delete)
finally:
compilelock.release_lock()
def clear_unversioned(self, min_age=None): def clear_unversioned(self, min_age=None):
""" """
...@@ -1288,9 +1239,8 @@ class ModuleCache(object): ...@@ -1288,9 +1239,8 @@ class ModuleCache(object):
if min_age is None: if min_age is None:
min_age = self.age_thresh_del_unversioned min_age = self.age_thresh_del_unversioned
compilelock.get_lock() with compilelock.lock_ctx():
all_key_datas = self.module_hash_to_key_data.values() all_key_datas = self.module_hash_to_key_data.values()
try:
for key_data in all_key_datas: for key_data in all_key_datas:
if not key_data.keys: if not key_data.keys:
# May happen for broken versioned keys. # May happen for broken versioned keys.
...@@ -1363,17 +1313,12 @@ class ModuleCache(object): ...@@ -1363,17 +1313,12 @@ class ModuleCache(object):
_rmtree(os.path.join(self.dirname, filename), _rmtree(os.path.join(self.dirname, filename),
msg='old unversioned', level=logging.INFO, msg='old unversioned', level=logging.INFO,
ignore_nocleanup=True) ignore_nocleanup=True)
finally:
compilelock.release_lock()
def _on_atexit(self): def _on_atexit(self):
# Note: no need to call refresh() since it is called by clear_old(). # Note: no need to call refresh() since it is called by clear_old().
compilelock.get_lock() with compilelock.lock_ctx():
try:
self.clear_old() self.clear_old()
self.clear_unversioned() self.clear_unversioned()
finally:
compilelock.release_lock()
_logger.debug('Time spent checking keys: %s', _logger.debug('Time spent checking keys: %s',
self.time_spent_in_check_key) self.time_spent_in_check_key)
......
...@@ -8,6 +8,8 @@ import socket # only used for gethostname() ...@@ -8,6 +8,8 @@ import socket # only used for gethostname()
import time import time
import logging import logging
from contextlib import contextmanager
from theano import config from theano import config
from theano.configparser import AddConfigVar, IntParam from theano.configparser import AddConfigVar, IntParam
...@@ -44,6 +46,14 @@ def force_unlock(): ...@@ -44,6 +46,14 @@ def force_unlock():
release_lock() release_lock()
@contextmanager
def lock_ctx(lock_dir=None, keep_lock=False, **kw):
    """
    Context manager that holds the compilation lock around a block.

    The lock is acquired with `get_lock(lock_dir=lock_dir, **kw)` before
    the body of the `with` statement runs.

    :param lock_dir: Forwarded unchanged to `get_lock()`.

    :param keep_lock: If True, the lock is not released when the block
        exits; the caller is then responsible for releasing it.

    :param kw: Extra keyword arguments forwarded to `get_lock()`.
    """
    # Acquire outside the try: if get_lock() itself fails there is
    # nothing to release.
    get_lock(lock_dir=lock_dir, **kw)
    try:
        yield
    finally:
        # Release even when the body raised, otherwise an exception
        # during compilation would leave the lock held.
        if not keep_lock:
            release_lock()
def get_lock(lock_dir=None, **kw): def get_lock(lock_dir=None, **kw):
""" """
Obtain lock on compilation directory. Obtain lock on compilation directory.
......
File mode changed from 100755 to 100644
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论