提交 3bcae5f9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged -- no conflict

...@@ -28,7 +28,7 @@ def local_bitwidth(): ...@@ -28,7 +28,7 @@ def local_bitwidth():
return len('%x' % maxint) * 4 return len('%x' % maxint) * 4
_logger=logging.getLogger("theano.gof.cmodule") _logger=logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.WARN) _logger.setLevel(logging.DEBUG)
def error(*args): def error(*args):
_logger.error("ERROR: "+' '.join(str(a) for a in args)) _logger.error("ERROR: "+' '.join(str(a) for a in args))
...@@ -264,6 +264,8 @@ def get_module_hash(src_code, key): ...@@ -264,6 +264,8 @@ def get_module_hash(src_code, key):
# if it does not match exactly what we expect. In the future we may # if it does not match exactly what we expect. In the future we may
# modify this behavior to be less strict and be able to accomodate # modify this behavior to be less strict and be able to accomodate
# changes to the key in an automatic way. # changes to the key in an automatic way.
# Note that if the key structure changes, the `get_safe_part` fucntion
# below may also need to be modified.
error_msg = ("This should not happen unless someone modified the code " error_msg = ("This should not happen unless someone modified the code "
"that defines the CLinker key, in which case you should " "that defines the CLinker key, in which case you should "
"ensure this piece of code is still valid (and this " "ensure this piece of code is still valid (and this "
...@@ -287,6 +289,31 @@ def get_module_hash(src_code, key): ...@@ -287,6 +289,31 @@ def get_module_hash(src_code, key):
return hash_from_code('\n'.join(to_hash)) return hash_from_code('\n'.join(to_hash))
def get_safe_part(key):
"""
Return a tuple containing a subset of `key`, to be used to find equal keys.
This tuple should only contain objects whose __eq__ and __hash__ methods
can be trusted (currently: the version part of the key, as well as the
md5 hash of the config options).
It is used to reduce the amount of key comparisons one has to go through
in order to find broken keys (i.e. keys with bad implementations of __eq__
or __hash__).
"""
version = key[0]
# This function should only be called on versioned keys.
assert version
# Find the md5 hash part.
c_link_key = key[1]
for key_element in c_link_key[1:]:
if isinstance(key_element, str) and key_element.startswith('md5:'):
md5 = key_element[4:]
break
return key[0] + (md5, )
class KeyData(object): class KeyData(object):
"""Used to store the key information in the cache.""" """Used to store the key information in the cache."""
...@@ -412,6 +439,9 @@ class ModuleCache(object): ...@@ -412,6 +439,9 @@ class ModuleCache(object):
"""Maps keys to the filename of a .so/.pyd. """Maps keys to the filename of a .so/.pyd.
""" """
similar_keys = {}
"""Maps a part-of-key to all keys that share this same part."""
module_hash_to_key_data = {} module_hash_to_key_data = {}
"""Maps hash of a module's code to its corresponding KeyData object.""" """Maps hash of a module's code to its corresponding KeyData object."""
...@@ -436,6 +466,7 @@ class ModuleCache(object): ...@@ -436,6 +466,7 @@ class ModuleCache(object):
self.module_from_name = dict(self.module_from_name) self.module_from_name = dict(self.module_from_name)
self.entry_from_key = dict(self.entry_from_key) self.entry_from_key = dict(self.entry_from_key)
self.module_hash_to_key_data = dict(self.module_hash_to_key_data) self.module_hash_to_key_data = dict(self.module_hash_to_key_data)
self.similar_keys = dict(self.similar_keys)
self.stats = [0, 0, 0] self.stats = [0, 0, 0]
self.check_for_broken_eq = check_for_broken_eq self.check_for_broken_eq = check_for_broken_eq
self.loaded_key_pkl = set() self.loaded_key_pkl = set()
...@@ -618,6 +649,11 @@ class ModuleCache(object): ...@@ -618,6 +649,11 @@ class ModuleCache(object):
# Assert that we have not already got this # Assert that we have not already got this
# entry somehow. # entry somehow.
assert entry not in self.module_from_name assert entry not in self.module_from_name
# Store safe part of versioned keys.
if key[0]:
self.similar_keys.setdefault(
get_safe_part(key),
[]).append(key)
else: else:
warning( warning(
"The same cache key is associated to " "The same cache key is associated to "
...@@ -867,6 +903,9 @@ class ModuleCache(object): ...@@ -867,6 +903,9 @@ class ModuleCache(object):
assert self.entry_from_key[k] == name assert self.entry_from_key[k] == name
else: else:
self.entry_from_key[k] = name self.entry_from_key[k] = name
if _version:
self.similar_keys.setdefault(get_safe_part(k),
[]).append(key)
if name in self.module_from_name: if name in self.module_from_name:
# May happen if we are re-using an existing module. # May happen if we are re-using an existing module.
...@@ -903,6 +942,18 @@ class ModuleCache(object): ...@@ -903,6 +942,18 @@ class ModuleCache(object):
"%s. Verify the __eq__ and __hash__ functions of your " "%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" % "Ops. The file is: %s. The key is: %s" %
(msg, key_pkl, key)) (msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this
# part of the key is not broken.
for other in self.similar_keys.get(get_safe_part(key), []):
if other is not key and other == key and hash(other) != hash(key):
raise AssertionError(
"Found two keys that are equal but have a different hash. "
"Verify the __eq__ and __hash__ functions of your Ops. "
"The keys are:\n %s\nand\n %s\n(found in %s)." %
(other, key, key_pkl))
self.time_spent_in_check_key += time.time() - start_time self.time_spent_in_check_key += time.time() - start_time
age_thresh_del = 60*60*24*31#31 days age_thresh_del = 60*60*24*31#31 days
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论