提交 3bcae5f9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged -- no conflict

......@@ -28,7 +28,7 @@ def local_bitwidth():
return len('%x' % maxint) * 4
_logger=logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.WARN)
_logger.setLevel(logging.DEBUG)
def error(*args):
    """Log all positional arguments, space-separated, at the ERROR level.

    Each argument is converted with `str` and the resulting message is
    prefixed with "ERROR: " before being sent to the module logger.
    """
    message = ' '.join(map(str, args))
    _logger.error("ERROR: " + message)
......@@ -264,6 +264,8 @@ def get_module_hash(src_code, key):
# if it does not match exactly what we expect. In the future we may
# modify this behavior to be less strict and be able to accommodate
# changes to the key in an automatic way.
# Note that if the key structure changes, the `get_safe_part` function
# below may also need to be modified.
error_msg = ("This should not happen unless someone modified the code "
"that defines the CLinker key, in which case you should "
"ensure this piece of code is still valid (and this "
......@@ -287,6 +289,31 @@ def get_module_hash(src_code, key):
return hash_from_code('\n'.join(to_hash))
def get_safe_part(key):
    """
    Return a tuple containing a subset of `key`, to be used to find equal keys.

    This tuple should only contain objects whose __eq__ and __hash__ methods
    can be trusted (currently: the version part of the key, as well as the
    md5 hash of the config options).

    It is used to reduce the amount of key comparisons one has to go through
    in order to find broken keys (i.e. keys with bad implementations of __eq__
    or __hash__).

    :param key: A versioned key. `key[0]` is the non-empty version tuple and
        `key[1]` is a sequence in which one element after the first is a
        string of the form 'md5:<hash>'.

    :returns: `key[0]` extended with one extra element: the text that follows
        the 'md5:' prefix of the first matching element.

    :raises AssertionError: If the version part of `key` is empty.
    :raises ValueError: If no 'md5:' element is found in `key[1][1:]`.
    """
    version = key[0]
    # This function should only be called on versioned keys.
    assert version, "get_safe_part must only be called on versioned keys"

    # Find the md5 hash part. The first element of the second key component
    # is deliberately skipped, exactly as the search has always done.
    c_link_key = key[1]
    for key_element in c_link_key[1:]:
        if isinstance(key_element, str) and key_element.startswith('md5:'):
            md5 = key_element[4:]
            break
    else:
        # Without this branch `md5` would be unbound below and the caller
        # would get an opaque UnboundLocalError instead of a useful message.
        raise ValueError(
            "Could not find the 'md5:' element in key: %s" % (key,))

    return version + (md5,)
class KeyData(object):
"""Used to store the key information in the cache."""
......@@ -412,6 +439,9 @@ class ModuleCache(object):
"""Maps keys to the filename of a .so/.pyd.
"""
similar_keys = {}
"""Maps a part-of-key to all keys that share this same part."""
module_hash_to_key_data = {}
"""Maps hash of a module's code to its corresponding KeyData object."""
......@@ -436,6 +466,7 @@ class ModuleCache(object):
self.module_from_name = dict(self.module_from_name)
self.entry_from_key = dict(self.entry_from_key)
self.module_hash_to_key_data = dict(self.module_hash_to_key_data)
self.similar_keys = dict(self.similar_keys)
self.stats = [0, 0, 0]
self.check_for_broken_eq = check_for_broken_eq
self.loaded_key_pkl = set()
......@@ -618,6 +649,11 @@ class ModuleCache(object):
# Assert that we have not already got this
# entry somehow.
assert entry not in self.module_from_name
# Store safe part of versioned keys.
if key[0]:
self.similar_keys.setdefault(
get_safe_part(key),
[]).append(key)
else:
warning(
"The same cache key is associated to "
......@@ -867,6 +903,9 @@ class ModuleCache(object):
assert self.entry_from_key[k] == name
else:
self.entry_from_key[k] = name
if _version:
self.similar_keys.setdefault(get_safe_part(k),
[]).append(key)
if name in self.module_from_name:
# May happen if we are re-using an existing module.
......@@ -903,6 +942,18 @@ class ModuleCache(object):
"%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" %
(msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this
# part of the key is not broken.
for other in self.similar_keys.get(get_safe_part(key), []):
if other is not key and other == key and hash(other) != hash(key):
raise AssertionError(
"Found two keys that are equal but have a different hash. "
"Verify the __eq__ and __hash__ functions of your Ops. "
"The keys are:\n %s\nand\n %s\n(found in %s)." %
(other, key, key_pkl))
self.time_spent_in_check_key += time.time() - start_time
age_thresh_del = 60*60*24*31#31 days
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论