Commit d24ce080 authored by Olivier Delalleau

Cache optimization: modules are not duplicated anymore in the cache

The idea is that even when the key changes, the resulting C code and compiled module can remain the same. Previously this produced duplicate modules in the cache; we now avoid the duplication by keeping a hash of the C code (as well as of the compilation options), and when this hash matches that of an existing module, we re-use that module instead of keeping two different modules that do the same thing. Two improvements should still be made to the current implementation:
1. Currently a module is still compiled before we detect that it is a duplicate; we should avoid this compilation step.
2. Different code or compilation options may actually lead to the exact same compiled module. It would be nice to detect this situation as well.
Parent a1d395b1
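To picture the idea behind the change, here is a minimal sketch. The names _modules_by_hash, _module_hash and get_or_compile are hypothetical and not part of the Theano API, and note that this sketch checks the hash before compiling, which is the behavior targeted by improvement 1 above (the current implementation still compiles first):

import hashlib

# Hypothetical registry: module hash -> already-compiled module.
_modules_by_hash = {}

def _module_hash(c_code, compile_args):
    # Hash the generated C source together with the compilation options.
    m = hashlib.md5()
    m.update(c_code)
    m.update('\n'.join(compile_args))
    return m.hexdigest()

def get_or_compile(c_code, compile_args, compile_fn):
    # Re-use the existing module when code + options hash to a known
    # module, otherwise compile and register a new one.
    h = _module_hash(c_code, compile_args)
    if h in _modules_by_hash:
        return _modules_by_hash[h]
    module = compile_fn(c_code, compile_args)
    _modules_by_hash[h] = module
    return module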
@@ -2,10 +2,12 @@
""" """
import os, tempfile, StringIO, sys, logging, subprocess, cPickle, atexit, time, shutil, stat import os, tempfile, StringIO, sys, logging, subprocess, cPickle, atexit, time, shutil, stat
import distutils.sysconfig import distutils.sysconfig
from theano.configparser import config
import numpy.distutils #TODO: TensorType should handle this import numpy.distutils #TODO: TensorType should handle this
import sys import theano
from theano.configparser import config
from theano.gof.cc import hash_from_code, hash_from_file
import compilelock # we will abuse the lockfile mechanism when reading and writing the registry import compilelock # we will abuse the lockfile mechanism when reading and writing the registry
from theano.configparser import TheanoConfigParser, AddConfigVar, EnumStr, StrParam, IntParam, FloatParam, BoolParam from theano.configparser import TheanoConfigParser, AddConfigVar, EnumStr, StrParam, IntParam, FloatParam, BoolParam
@@ -202,6 +204,78 @@ def module_name_from_dir(dirname):
     name, = [file for file in files if file.endswith('.so') or file.endswith('.pyd')]
     return os.path.join(dirname, name)
+def get_module_hash(module_file, key):
+    """
+    Return an MD5 hash that identifies a module.
+    This hash takes into account:
+        1. The 'mod.cpp' file used to compile `module_file`.
+        2. The compiler options defined in `key`.
+    """
+    source_code = os.path.join(os.path.dirname(module_file), 'mod.cpp')
+    source_hash = hash_from_file(source_code)
+    c_link_key = key[1]
+    # Currently, in order to catch potential bugs early, we are very
+    # conservative about the structure of the key and raise an exception
+    # if it does not match exactly what we expect. In the future we may
+    # modify this behavior to be less strict and be able to accommodate
+    # changes to the key in an automatic way.
+    error_msg = ("This should not happen unless someone modified the code "
+                 "that defines the CLinker key, in which case you should "
+                 "ensure this piece of code is still valid (and this "
+                 "AssertionError may be removed or modified to accommodate "
+                 "this change)")
+    assert c_link_key[0] == 'CLinker.cmodule_key', error_msg
+    to_hash = [source_hash]
+    for key_element in c_link_key[1:]:
+        if isinstance(key_element, tuple):
+            to_hash += list(key_element)
+        elif isinstance(key_element, str):
+            if key_element.startswith('md5:'):
+                # This is the md5 hash of the config options. We can stop
+                # here.
+                break
+            else:
+                raise AssertionError(error_msg)
+        else:
+            raise AssertionError(error_msg)
+    return hash_from_code('\n'.join(to_hash))
+class KeyData(object):
+    """Used to store the key information in the cache."""
+    def __init__(self, keys, module_hash, key_pkl):
+        """
+        Constructor.
+        :param keys: Set of keys that are associated to the exact same module.
+        :param module_hash: Hash identifying the module (it should hash both
+            the code and the compilation options).
+        :param key_pkl: Path to the file in which this KeyData object should
+            be pickled.
+        """
+        self.keys = keys
+        self.module_hash = module_hash
+        self.key_pkl = key_pkl
+    def add_key(self, key):
+        """Add a key to the `keys` set, and update the pickled file."""
+        assert key not in self.keys
+        self.keys.add(key)
+        self.save_pkl()
+    def save_pkl(self):
+        """Dump this object into its `key_pkl` file."""
+        cPickle.dump(self, open(self.key_pkl, 'wb'),
+                     protocol=cPickle.HIGHEST_PROTOCOL)
 class ModuleCache(object):
     """Interface to the cache of dynamically compiled modules on disk
@@ -239,6 +313,9 @@ class ModuleCache(object):
"""Maps keys to the filename of a .so/.pyd. """Maps keys to the filename of a .so/.pyd.
""" """
+    loaded_modules_hash = {}
+    """Maps hash of a module's code to its corresponding KeyData object."""
     stats = []
     """A list with counters for the number of hits, loads, compiles issued by module_from_key()
     """
@@ -260,6 +337,7 @@ class ModuleCache(object):
         self.dirname = dirname
         self.module_from_name = dict(self.module_from_name)
         self.entry_from_key = dict(self.entry_from_key)
+        self.loaded_modules_hash = dict(self.loaded_modules_hash)
         self.stats = [0, 0, 0]
         if force_fresh is not None:
             self.force_fresh = force_fresh
@@ -325,33 +403,76 @@ class ModuleCache(object):
info("Erasing broken cache directory", key_pkl) info("Erasing broken cache directory", key_pkl)
shutil.rmtree(root) shutil.rmtree(root)
continue continue
if (time_now - last_access_time(entry))<self.age_thresh_use: if (time_now - last_access_time(entry)) < self.age_thresh_use:
debug('refresh adding', key_pkl) debug('refresh adding', key_pkl)
try: try:
key = cPickle.load(open(key_pkl, 'rb')) key_data = cPickle.load(open(key_pkl, 'rb'))
except: except:
info("ModuleCache.refresh() Failed to unpickle cache key", key_pkl) info("ModuleCache.refresh() Failed to unpickle "
if 0: "cache file", key_pkl)
if False:
info("Erasing broken cache directory", key_pkl) info("Erasing broken cache directory", key_pkl)
shutil.rmtree(root) shutil.rmtree(root)
else: else:
## This exception is often triggered by keys that contain # This exception is often triggered by keys that contain
# references to classes that have not yet been imported. They are # references to classes that have not yet been imported. They are
# not necessarily broken # not necessarily broken
pass pass
continue continue
if not key[0]: #if the version is False if not isinstance(key_data, KeyData):
warning("ModuleCache.refresh() Found unversioned key in cache, deleting it.", key_pkl) # Backward-compatibility with older cache mechanism
# that used single keys with no hash of the
# compiled file.
key_data = KeyData(
keys=set([key_data]),
module_hash=get_module_hash(entry, key_data),
key_pkl=key_pkl)
debug("Updating cache key to new format", key_pkl)
key_data.save_pkl()
# Find unversioned keys.
to_del = [key for key in key_data.keys if not key[0]]
if to_del:
warning("ModuleCache.refresh() Found unversioned "
"key in cache, removing it.", key_pkl)
if len(to_del) == len(key_data.keys):
# All keys were unversioned.
info("Erasing broken cache directory", key_pkl) info("Erasing broken cache directory", key_pkl)
shutil.rmtree(root) shutil.rmtree(root)
continue continue
else:
# Fix the pickled file to only keep the
# versioned keys.
info("Fixing broken cache directory", key_pkl)
key_data.keys = set(
[key for key in key_data.keys
if key[0]])
key_data.save_pkl()
for key in key_data.keys:
if key not in self.entry_from_key: if key not in self.entry_from_key:
self.entry_from_key[key] = entry self.entry_from_key[key] = entry
# assert that we haven't already got this entry somehow # Assert that we have not already got this
# entry somehow.
assert entry not in self.module_from_name assert entry not in self.module_from_name
self.loaded_key_pkl.add(key_pkl) self.loaded_key_pkl.add(key_pkl)
# Remember the map from a module's hash to the KeyData
# object associated with it.
mod_hash = key_data.module_hash
if mod_hash in self.loaded_modules_hash:
# This should not happen anymore, but may happen
# with the previous cache mechanism, that did not
# ensure uniqueness of the compiled modules.
# TODO Convert into an error in the future.
warning(
"Found duplicated modules in the cache, you "
"are probably using an old cache. Clear it "
"with 'theano-cache clear' to benefit from "
"recent cache optimizations.")
else:
self.loaded_modules_hash[mod_hash] = key_data
else: else:
too_old_to_use.append(entry) too_old_to_use.append(entry)
@@ -449,12 +570,36 @@ class ModuleCache(object):
             assert hash(key) == hash_key
             assert key not in self.entry_from_key
-            if _version: # save they key
+            # Check if we already know a module with the same hash.
+            duplicated_module = False
+            module_hash = get_module_hash(name, key)
+            if module_hash in self.loaded_modules_hash:
+                debug("Duplicated module! Will re-use the previous one")
+                duplicated_module = True
+                # Load the already existing module.
+                key_data = self.loaded_modules_hash[module_hash]
+                module = self.module_from_key(
+                        key=key_data.keys.__iter__().next(),
+                        keep_lock=True)
+                # Add current key to the set of keys associated to the same
+                # module.
+                key_data.add_key(key)
+                # We can delete this module.
+                debug("Deleting: ", os.path.dirname(name))
+                shutil.rmtree(os.path.dirname(name))
+                name = module.__file__
+            if not duplicated_module and _version: # save the key
                 key_pkl = os.path.join(location, 'key.pkl')
+                key_data = KeyData(
+                        keys=set([key]),
+                        module_hash=get_module_hash(name, key),
+                        key_pkl=key_pkl)
                 # Note that using a binary file is important under Windows.
                 key_file = open(key_pkl, 'wb')
                 try:
-                    cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL)
+                    cPickle.dump(key_data, key_file,
+                                 cPickle.HIGHEST_PROTOCOL)
                     key_file.close()
                     key_broken = False
                 except cPickle.PicklingError:
@@ -465,7 +610,9 @@ class ModuleCache(object):
                 if not key_broken:
                     try:
-                        key_from_file = cPickle.load(open(key_pkl, 'rb'))
+                        kd_from_file = cPickle.load(open(key_pkl, 'rb'))
+                        assert len(kd_from_file.keys) == 1
+                        key_from_file = kd_from_file.keys.__iter__().next()
                         if key != key_from_file:
                             raise Exception(
                                 "key not equal to unpickled version (Hint:"
@@ -474,9 +621,11 @@ class ModuleCache(object):
                         # Adding the key file to this set means it is a
                         # versioned key.
                         self.loaded_key_pkl.add(key_pkl)
+                        self.loaded_modules_hash[module_hash] = key_data
                     except cPickle.UnpicklingError:
                         warning('Cache failure due to un-loadable key',
                                 key)
         finally:
             # Release lock if needed.
             if not keep_lock:
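To see how the new pieces introduced in this diff fit together, here is a minimal sketch. It uses only KeyData and get_module_hash from the diff above; key1, key2, module_file and the key.pkl path are hypothetical stand-ins for values normally produced by the cache:

# Two keys whose generated C code and compilation options are identical
# yield the same module hash, so the second key is simply added to the
# existing KeyData instead of creating a duplicate cache entry.
key_pkl = '/tmp/theano_cache_example/key.pkl'   # hypothetical path
key_data = KeyData(keys=set([key1]),
                   module_hash=get_module_hash(module_file, key1),
                   key_pkl=key_pkl)
key_data.save_pkl()                 # writes key.pkl next to the module

if get_module_hash(module_file, key2) == key_data.module_hash:
    key_data.add_key(key2)          # both keys now map to the same module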