提交 3f0d61d4 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Faster mechanism to check broken __eq__ and __hash__ implementations in C cache

The previous mechanism could take over a minute with a big cache. It did two checks: 1. (expensive check) Compare all loaded keys pairwise, and find those that are equal even though they have a different hash. 2. (cheap check) Ensure that unpickling a pickled key yields a key that is equal to the original one. Those checks are now replaced by a new check that is performed each time a new key is saved in the cache: we reload the pickled KeyData object and ensure that it contains exactly one unpickled key equal to the key originally saved. This obviously catches the errors that would have been caught by check #2. It also catches the errors that would have been caught by check #1: if two keys are equal, they must yield the same module hash, and thus they will be stored in the same KeyData object; if they have a different key hash, they will both appear in the 'keys' set of the KeyData object, and thus the new check will complain that more than one key was found to be equal to the one we just saved. NB: Also removed a couple of TODOs now that I understand better why some code was written that way.
上级 5f1c7026
...@@ -430,9 +430,9 @@ class ModuleCache(object): ...@@ -430,9 +430,9 @@ class ModuleCache(object):
def __init__(self, dirname, force_fresh=None, check_for_broken_eq=True, def __init__(self, dirname, force_fresh=None, check_for_broken_eq=True,
do_refresh=True): do_refresh=True):
""" """
:param check_for_broken_eq: A bad __eq__ implemenation can break this cache mechanism. :param check_for_broken_eq: A bad __eq__ implementation can break this
This option turns on a not-too-expensive sanity check during the load of an old cache cache mechanism. This option turns on a not-too-expensive sanity check
file. every time a new key is added to the cache.
:param do_refresh: If True, then the ``refresh`` method will be called :param do_refresh: If True, then the ``refresh`` method will be called
in the constructor. in the constructor.
...@@ -445,35 +445,13 @@ class ModuleCache(object): ...@@ -445,35 +445,13 @@ class ModuleCache(object):
if force_fresh is not None: if force_fresh is not None:
# TODO Where is / was `force_fresh` used? # TODO Where is / was `force_fresh` used?
self.force_fresh = force_fresh self.force_fresh = force_fresh
self.check_for_broken_eq = check_for_broken_eq
self.loaded_key_pkl = set() self.loaded_key_pkl = set()
self.time_spent_in_check_key = 0
if do_refresh: if do_refresh:
self.refresh() self.refresh()
start = time.time()
if check_for_broken_eq:
# Speed up comparison by only comparing keys for which the version
# part is equal (since the version part is supposed to be made of
# integer tuples, we can assume its hash is properly implemented).
version_to_keys = {}
for key in self.entry_from_key:
version_to_keys.setdefault(key[0], []).append(key)
for k0 in self.entry_from_key:
# Compare `k0` to all keys `k1` with the same version part.
for k1 in version_to_keys[k0[0]]:
if k0 is not k1 and k0 == k1:
warning(("The __eq__ and __hash__ functions are broken for some element"
" in the following two keys. The cache mechanism will say that"
" graphs like this need recompiling, when they could have been"
" retrieved:"))
warning("Key 0:", k0)
warning("Entry 0:", self.entry_from_key[k0])
warning("hash 0:", hash(k0))
warning("Key 1:", k1)
warning("Entry 1:", self.entry_from_key[k1])
warning("hash 1:", hash(k1))
debug('Time needed to check broken equality / hash: %s' % (time.time() - start))
age_thresh_use = 60*60*24*24 age_thresh_use = 60*60*24*24
""" """
The default age threshold (in seconds) for cache files we want to use. The default age threshold (in seconds) for cache files we want to use.
...@@ -559,19 +537,18 @@ class ModuleCache(object): ...@@ -559,19 +537,18 @@ class ModuleCache(object):
level='warning') level='warning')
continue continue
except: except:
# TODO Note that in a development version, it may
# be better to raise exceptions instead of silently
# catching them.
unpickle_failure() unpickle_failure()
if delete_if_problem: if delete_if_problem:
_rmtree(root, ignore_nocleanup=True, _rmtree(root, ignore_nocleanup=True,
msg='broken cache directory', msg='broken cache directory',
level='info') level='info')
else: else:
# This exception is often triggered by keys that contain # This exception is often triggered by keys
# references to classes that have not yet been imported. They are # that contain references to classes that have
# not necessarily broken. # not yet been imported (e.g. when running two
# TODO But is there a reason to keep them? # different Theano-based scripts). They are not
# necessarily broken, but we cannot load them
# here.
pass pass
continue continue
...@@ -788,6 +765,7 @@ class ModuleCache(object): ...@@ -788,6 +765,7 @@ class ModuleCache(object):
# modules. # modules.
try: try:
key_data.add_key(key, save_pkl=bool(_version)) key_data.add_key(key, save_pkl=bool(_version))
key_broken = False
except cPickle.PicklingError: except cPickle.PicklingError:
# This should only happen if we tried to save the # This should only happen if we tried to save the
# pickled file. # pickled file.
...@@ -797,6 +775,9 @@ class ModuleCache(object): ...@@ -797,6 +775,9 @@ class ModuleCache(object):
key_data.remove_key(key) key_data.remove_key(key)
key_broken = True key_broken = True
if not key_broken and self.check_for_broken_eq:
self.check_key(key, key_data.key_pkl)
# We can delete the work directory. # We can delete the work directory.
_rmtree(location, ignore_nocleanup=True, _rmtree(location, ignore_nocleanup=True,
msg='temporary workdir of duplicated module') msg='temporary workdir of duplicated module')
...@@ -846,23 +827,8 @@ class ModuleCache(object): ...@@ -846,23 +827,8 @@ class ModuleCache(object):
key_data.keys = set() key_data.keys = set()
key_data.save_pkl() key_data.save_pkl()
# TODO We should probably have a similar sanity check if not key_broken and self.check_for_broken_eq:
# when we add a new key to an existing KeyData object, self.check_key(key, key_pkl)
# not just when we create a brand new one.
if not key_broken:
try:
kd2 = cPickle.load(open(key_pkl, 'rb'))
assert len(kd2.keys) == 1
key_from_file = kd2.keys.__iter__().next()
if key != key_from_file:
raise Exception(
"Key not equal to unpickled version "
"(Hint: verify the __eq__ and "
"__hash__ functions for your Ops",
(key, key_from_file))
except cPickle.UnpicklingError:
warning('Cache failure due to un-loadable key',
key)
# Adding the KeyData file to this set means it is a # Adding the KeyData file to this set means it is a
# versioned module. # versioned module.
...@@ -919,6 +885,31 @@ class ModuleCache(object): ...@@ -919,6 +885,31 @@ class ModuleCache(object):
#debug('stats', self.stats, sum(self.stats)) #debug('stats', self.stats, sum(self.stats))
return rval return rval
def check_key(self, key, key_pkl):
    """
    Perform checks to detect broken __eq__ / __hash__ implementations.

    The KeyData object is reloaded from `key_pkl`, and `key` is compared
    (with ``==``) to every key found in it: exactly one of them must be
    equal to `key`.

    :param key: The key to be checked.

    :param key_pkl: Its associated pickled file containing a KeyData.

    :raises AssertionError: If no key, or more than one key, of the
        reloaded KeyData is equal to `key`.
    """
    start_time = time.time()
    # Verify that when we reload the KeyData from the pickled file, the
    # same key can be found in it, and is not equal to more than one
    # other key.
    # The file is closed explicitly (instead of relying on garbage
    # collection) since this check runs each time a key is added.
    pkl_file = open(key_pkl, 'rb')
    try:
        key_data = cPickle.load(pkl_file)
    finally:
        pkl_file.close()
    # Each comparison yields a boolean, so the sum is the number of
    # unpickled keys that compare equal to `key`.
    found = sum(key == other_key for other_key in key_data.keys)
    msg = ''
    if found == 0:
        msg = 'Key not found in unpickled KeyData file'
    elif found > 1:
        msg = 'Multiple equal keys found in unpickled KeyData file'
    if msg:
        raise AssertionError(
            "%s. Verify the __eq__ and __hash__ functions of your "
            "Ops. The file is: %s. The key is: %s" %
            (msg, key_pkl, key))
    # Only successful checks are accounted for: a failed check raises
    # before reaching this line.
    self.time_spent_in_check_key += time.time() - start_time
age_thresh_del = 60*60*24*31#31 days age_thresh_del = 60*60*24*31#31 days
age_thresh_del_unversioned = 60*60*24*7#7 days age_thresh_del_unversioned = 60*60*24*7#7 days
...@@ -1097,6 +1088,7 @@ class ModuleCache(object): ...@@ -1097,6 +1088,7 @@ class ModuleCache(object):
self.clear_unversioned() self.clear_unversioned()
finally: finally:
compilelock.release_lock() compilelock.release_lock()
debug('Time spent checking keys: %s' % self.time_spent_in_check_key)
def _rmtree(parent, ignore_nocleanup=False, msg='', level='debug', def _rmtree(parent, ignore_nocleanup=False, msg='', level='debug',
ignore_if_missing=False): ignore_if_missing=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论