提交 120b1ee9 authored 作者: James Bergstra's avatar James Bergstra

added some more checks to key pickles in cache mechanism

上级 54b0b796
...@@ -7,7 +7,7 @@ import numpy.distutils #TODO: TensorType should handle this ...@@ -7,7 +7,7 @@ import numpy.distutils #TODO: TensorType should handle this
import compilelock # we will abuse the lockfile mechanism when reading and writing the registry import compilelock # we will abuse the lockfile mechanism when reading and writing the registry
_logger=logging.getLogger("theano.gof.cmodule") _logger=logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.WARN) _logger.setLevel(logging.INFO)
def error(*args): def error(*args):
#sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
...@@ -274,6 +274,11 @@ class ModuleCache(object): ...@@ -274,6 +274,11 @@ class ModuleCache(object):
info("Erasing broken file", key_pkl) info("Erasing broken file", key_pkl)
os.remove(key_pkl) os.remove(key_pkl)
continue continue
if not key[0]: #if the version is False
warning("ModuleCache.refresh() Found unversioned key in cache, deleting it.", key_pkl)
info("Erasing broken file", key_pkl)
os.remove(key_pkl)
continue
if key not in self.entry_from_key: if key not in self.entry_from_key:
entry = module_name_from_dir(root) entry = module_name_from_dir(root)
self.entry_from_key[key] = entry self.entry_from_key[key] = entry
...@@ -301,7 +306,9 @@ class ModuleCache(object): ...@@ -301,7 +306,9 @@ class ModuleCache(object):
info("deleting ModuleCache entry", entry) info("deleting ModuleCache entry", entry)
del self.entry_from_key[key] del self.entry_from_key[key]
self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl')) if key[0]:
#this is a versioned entry, so should have been on disk
self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl'))
finally: finally:
compilelock.release_lock() compilelock.release_lock()
...@@ -342,25 +349,26 @@ class ModuleCache(object): ...@@ -342,25 +349,26 @@ class ModuleCache(object):
assert name.startswith(location) assert name.startswith(location)
assert name not in self.module_from_name assert name not in self.module_from_name
assert key not in self.entry_from_key assert key not in self.entry_from_key
key_pkl = os.path.join(location, 'key.pkl') if _version: # save they key
key_file = file(key_pkl, 'w') key_pkl = os.path.join(location, 'key.pkl')
try: key_file = file(key_pkl, 'w')
cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL) try:
key_file.close() cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL)
key_broken = False key_file.close()
except cPickle.PicklingError: key_broken = False
key_file.close() except cPickle.PicklingError:
os.remove(key_pkl) key_file.close()
warning("Cache leak due to unpickle-able key", key) os.remove(key_pkl)
key_broken = True warning("Cache leak due to unpickle-able key", key)
key_broken = True
if _version and not key_broken:
key_from_file = cPickle.load(file(key_pkl)) if not key_broken:
if key != key_from_file: key_from_file = cPickle.load(file(key_pkl))
raise Exception("key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops", (key, key_from_file)) if key != key_from_file:
raise Exception("key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops", (key, key_from_file))
self.loaded_key_pkl.add(key_pkl)
self.entry_from_key[key] = name self.entry_from_key[key] = name
self.module_from_name[name] = module self.module_from_name[name] = module
self.loaded_key_pkl.add(key_pkl)
self.stats[2] += 1 self.stats[2] += 1
rval = module rval = module
...@@ -448,6 +456,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -448,6 +456,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
include_dirs = [distutils.sysconfig.get_python_inc()] + \ include_dirs = [distutils.sysconfig.get_python_inc()] + \
numpy.distutils.misc_util.get_numpy_include_dirs()\ numpy.distutils.misc_util.get_numpy_include_dirs()\
+ include_dirs + include_dirs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论