Commit 753ad528 authored by Arnaud Bergeron

Change the ModuleCache to only take the lock for actual compilation.

Also, once we acquire the lock do a refresh of the cache and try to pick up modules compiled by others.
Parent 34a15a3f
...@@ -617,6 +617,16 @@ class ModuleCache(object): ...@@ -617,6 +617,16 @@ class ModuleCache(object):
Older modules will be deleted in ``clear_old``. Older modules will be deleted in ``clear_old``.
""" """
def _get_module(self, name):
    """Return the compiled module registered under `name`.

    On a cache hit the already-loaded module is returned and the
    hit counter (``stats[0]``) is bumped; otherwise the module is
    loaded with ``dlimport`` and remembered in ``module_from_name``,
    bumping the load counter (``stats[1]``).
    """
    cached = name in self.module_from_name
    if cached:
        _logger.debug('returning compiled module from cache %s', name)
        self.stats[0] += 1
    else:
        _logger.debug('loading name %s', name)
        self.module_from_name[name] = dlimport(name)
        self.stats[1] += 1
    return self.module_from_name[name]
def refresh(self, age_thresh_use=None, delete_if_problem=False): def refresh(self, age_thresh_use=None, delete_if_problem=False):
"""Update cache data by walking the cache directory structure. """Update cache data by walking the cache directory structure.
...@@ -877,127 +887,88 @@ class ModuleCache(object): ...@@ -877,127 +887,88 @@ class ModuleCache(object):
return too_old_to_use return too_old_to_use
def module_from_key(self, key, fn=None, keep_lock=False, key_data=None): def _get_from_key(self, key, key_data=None):
""" """
:param fn: A callable object that will return an iterable object when Returns a module if the passed-in key is found in the cache
called, such that the first element in this iterable object is the and None otherwise.
source code of the module, and the last element is the module itself.
`fn` is called only if the key is not already in the cache, with
a single keyword argument `location` that is the path to the directory
where the module should be compiled.
:param key_data: If not None, it should be a KeyData object and the May raise ValueError if the key is malformed.
key parameter should be None. In this case, we use the info from the
KeyData object to recover the module, rather than the key itself. Note
that this implies the module already exists (and may or may not have
already been loaded).
""" """
# We should only use one of the two ways to get a module. name = None
assert key_data is None or key is None
rval = None
if key is not None: if key is not None:
assert key_data is None
try: try:
_version, _rest = key _version, _rest = key
except (TypeError, ValueError): except (TypeError, ValueError):
raise ValueError( raise ValueError(
"Invalid key. key must have form (version, rest)", key) "Invalid key. key must have form (version, rest)", key)
name = None if name in self.entry_from_key:
if key is not None and key in self.entry_from_key:
# We have seen this key either in this process or previously.
name = self.entry_from_key[key] name = self.entry_from_key[key]
elif key_data is not None:
name = key_data.get_entry()
if name is not None:
# This is an existing module we can recover.
if name not in self.module_from_name:
_logger.debug('loading name %s', name)
self.module_from_name[name] = dlimport(name)
self.stats[1] += 1
else: else:
self.stats[0] += 1 assert key_data is not None
_logger.debug('returning compiled module from cache %s', name) name = key_data.get_entry()
rval = self.module_from_name[name] if name is None:
else: return None
hash_key = hash(key) return self._get_module(name)
key_data = None
# We have never seen this key before.
# We acquire the lock later only if we were able to
# generate C code. Otherwise, we would take the lock for ops
# that have only a perform().
lock_taken = False
# This try/finally block ensures that the lock is released once we
# are done writing in the cache file or after raising an exception.
try:
# Embedding two try statements for Python 2.4 compatibility
# (cannot do try / except / finally).
try:
location = dlimport_workdir(self.dirname)
except OSError, e:
_logger.error(e)
if e.errno == 31:
_logger.error('There are %i files in %s',
len(os.listdir(config.compiledir)),
config.compiledir)
raise
try:
compile_steps = fn(location=location).__iter__()
# Check if we already know a module with the same hash.
# If we do, then there is no need to even compile it.
duplicated_module = False
# The first compilation step is to yield the source code.
src_code = next(compile_steps)
module_hash = get_module_hash(src_code, key)
# The op has c_code, so take the lock.
compilelock.get_lock()
lock_taken = True
if not os.path.exists(location):
# Temporary fix, we should make sure it don't
# get deleted by the clear*() fct.
os.makedirs(location)
def _get_from_hash(self, module_hash, key, keep_lock=False):
if module_hash in self.module_hash_to_key_data: if module_hash in self.module_hash_to_key_data:
_logger.debug("Duplicated module! Will re-use the " _logger.debug("Duplicated module! Will re-use the "
"previous one") "previous one")
duplicated_module = True
# Load the already existing module.
key_data = self.module_hash_to_key_data[module_hash] key_data = self.module_hash_to_key_data[module_hash]
# Note that we do not pass the `fn` argument, since it module = self._get_from_key(None, key_data)
# should not be used considering that the module should lock_taken = False
# already be compiled.
module = self.module_from_key(key=None,
key_data=key_data)
name = module.__file__
# Add current key to the set of keys associated to the
# same module. We only save the KeyData object of
# versioned modules.
try: try:
key_data.add_key(key, save_pkl=bool(_version)) compilelock.get_lock()
lock_taken = True
key_data.add_key(key, save_pkl=bool(key[0]))
key_broken = False key_broken = False
except cPickle.PicklingError: except cPickle.PicklingError:
# This should only happen if we tried to save the
# pickled file.
assert _version
# The key we are trying to add is broken: we will
# not add it after all.
key_data.remove_key(key) key_data.remove_key(key)
key_broken = True key_broken = True
finally:
if (_version and not key_broken and if lock_taken and not keep_lock:
compilelock.release_lock()
if (key[0] and not key_broken and
self.check_for_broken_eq): self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl) self.check_key(key, key_data.key_pkl)
self._update_mappings(key, key_data, module.__file__)
return module
else:
return None
# We can delete the work directory. def _update_mappings(self, key, key_data, name):
_rmtree(location, ignore_nocleanup=True, all_keys = key_data.keys
msg='temporary workdir of duplicated module') if not all_keys:
all_keys = [key]
assert key in all_keys
for k in all_keys:
if k in self.entry_from_key:
assert self.entry_from_key[k] == name
else: else:
# Will fail if there is an error compiling the C code. self.entry_from_key[k] = name
# The exception will be caught and the work dir will be if key[0]:
# deleted. self.similar_keys.setdefault(get_safe_part(k),
[]).append(key)
def _compile_code(self, compile_steps):
"""
Compiles the passed-in source code.
This expects that the compile lock is held during the call.
"""
location = None
try:
location = dlimport_workdir(self.dirname)
except OSError, e:
_logger.error(e)
if e.errno == 31:
_logger.error('There are %i files in %s',
len(os.listdir(config.compiledir)),
config.compiledir)
raise
try:
while True: while True:
try: try:
# The module should be returned by the last # The module should be returned by the last
...@@ -1005,18 +976,25 @@ class ModuleCache(object): ...@@ -1005,18 +976,25 @@ class ModuleCache(object):
module = next(compile_steps) module = next(compile_steps)
except StopIteration: except StopIteration:
break break
# Obtain path to the '.so' module file.
name = module.__file__ name = module.__file__
assert name.startswith(location)
assert name not in self.module_from_name
return module
except Exception:
_rmtree(location, ignore_if_missing=True,
msg='exception during compilation')
raise
def _add_to_cache(self, module, key, module_hash):
"""
This function expects the compile lock to be held.
"""
name = module.__file__
_logger.debug("Adding module to cache %s %s", _logger.debug("Adding module to cache %s %s",
key, name) key, name)
assert name.startswith(location)
assert name not in self.module_from_name
# Changing the hash of the key is not allowed during # Changing the hash of the key is not allowed during
# compilation. That is the only cause found that makes # compilation. That is the only cause found that makes
# the following assert fail. # the following assert fail.
assert hash(key) == hash_key
assert key not in self.entry_from_key assert key not in self.entry_from_key
key_pkl = os.path.join(location, 'key.pkl') key_pkl = os.path.join(location, 'key.pkl')
...@@ -1027,93 +1005,82 @@ class ModuleCache(object): ...@@ -1027,93 +1005,82 @@ class ModuleCache(object):
key_pkl=key_pkl, key_pkl=key_pkl,
entry=name) entry=name)
# Note that we only save KeyData objects associated to if key[0]:
# versioned modules. So for unversioned key, the
# `key_pkl` field of the KeyData object will be a
# non-existing file (which does not matter since it
# will not be accessed).
if _version:
try: try:
key_data.save_pkl() key_data.save_pkl()
key_broken = False key_broken = False
except cPickle.PicklingError: except cPickle.PicklingError:
key_broken = True key_broken = True
# Remove key from the KeyData object, to make key_data.remove_key(key)
# sure we never try to save it again.
# We still keep the KeyData object and save it
# so that the module can be re-used in the
# future.
key_data.keys = set()
key_data.save_pkl() key_data.save_pkl()
if not key_broken and self.check_for_broken_eq: if not key_broken and self.check_for_broken_eq:
self.check_key(key, key_pkl) self.check_key(key, key_pkl)
# Adding the KeyData file to this set means it is a
# versioned module.
self.loaded_key_pkl.add(key_pkl) self.loaded_key_pkl.add(key_pkl)
elif config.cmodule.warn_no_version: elif config.cmodule.warn_no_version:
key_flat = flatten(key) key_flat = flatten(key)
ops = [k for k in key_flat ops = [k for k in key_flat if isinstance(k, theano.Op)]
if isinstance(k, theano.Op)]
_logger.warning("not all the" _logger.warning("not all the"
" following op(s) implement" " following op(s) implement"
" c_code_cache_version(). This makes them" " c_code_cache_version(). This makes them"
" recompiled for each process." + str(ops)) " recompiled for each process." + str(ops))
self._update_mappings(key, key_data, module)
return key_data
# Map the new module to its KeyData object. Note that def module_from_key(self, key, fn=None, keep_lock=False):
# we need to do it regardless of whether the key is """
# versioned or not if we want to be able to re-use this :param fn: A callable object that will return an iterable object when
# module inside the same process. called, such that the first element in this iterable object is the
self.module_hash_to_key_data[module_hash] = key_data source code of the module, and the last element is the module itself.
`fn` is called only if the key is not already in the cache, with
a single keyword argument `location` that is the path to the directory
where the module should be compiled.
"""
# Is the module in the cache?
module = self._get_from_key(key)
if module is not None:
return module
except Exception: # Is the source code already in the cache?
# This may happen e.g. when an Op has no C implementation. compile_steps = fn(location=location).__iter__()
# In any case, we do not want to keep around the temporary src_code = next(compile_steps)
# work directory, as it may cause trouble if we create too module_hash = get_module_hash(src_code, key)
# many of these. The 'ignore_if_missing' flag is set just module = self._get_from_hash(module_hash, key, keep_lock=keep_lock)
# in case this directory would have already been deleted. if module is not None:
_rmtree(location, ignore_if_missing=True, return module
msg=('exception -- '
'typically means no C implementation')) # Compile the module since it's not cached
raise try:
# The op has c_code, so take the lock.
compilelock.get_lock()
lock_taken = True
# Maybe somebody else compiled it for us while we
# were waiting for the lock. Try to load it again
self.refresh()
module = self._get_from_key(key)
if module is not None:
return module
module = self._get_from_hash(module_hash, key, keep_lock=keep_lock)
if module is not None:
return module
hash_key = hash(key)
module = self._compile_module(compile_steps)
# Changing the hash of the key is not allowed during
# compilation.
assert hash(key) == hash_key
key_data = self._add_to_cache(module, key)
self.module_hash_to_key_data[module_hash] = key_data
finally: finally:
# Release lock if needed. # Release lock if needed.
if not keep_lock and lock_taken: if not keep_lock and lock_taken:
compilelock.release_lock() compilelock.release_lock()
# Update map from key to module name for all keys associated to
# this same module.
all_keys = key_data.keys
if not all_keys:
# Should only happen for broken keys.
assert key_broken
all_keys = [key]
else:
assert key in key_data.keys
for k in all_keys:
if k in self.entry_from_key:
# If we had already seen this key, then it should be
# associated to the same module.
assert self.entry_from_key[k] == name
else:
self.entry_from_key[k] = name
if _version:
self.similar_keys.setdefault(get_safe_part(k),
[]).append(key)
if name in self.module_from_name:
# May happen if we are re-using an existing module.
assert duplicated_module
assert self.module_from_name[name] is module
else:
self.module_from_name[name] = module
self.stats[2] += 1 self.stats[2] += 1
rval = module return module
#_logger.debug('stats %s %i', self.stats, sum(self.stats))
return rval
def check_key(self, key, key_pkl): def check_key(self, key, key_pkl):
""" """
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment