提交 120b1ee9 authored 作者: James Bergstra's avatar James Bergstra

added some more checks to key pickles in cache mechanism

上级 54b0b796
...@@ -7,7 +7,7 @@ import numpy.distutils #TODO: TensorType should handle this ...@@ -7,7 +7,7 @@ import numpy.distutils #TODO: TensorType should handle this
import compilelock # we will abuse the lockfile mechanism when reading and writing the registry import compilelock # we will abuse the lockfile mechanism when reading and writing the registry
_logger=logging.getLogger("theano.gof.cmodule") _logger=logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.WARN) _logger.setLevel(logging.INFO)
def error(*args): def error(*args):
#sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
...@@ -274,6 +274,11 @@ class ModuleCache(object): ...@@ -274,6 +274,11 @@ class ModuleCache(object):
info("Erasing broken file", key_pkl) info("Erasing broken file", key_pkl)
os.remove(key_pkl) os.remove(key_pkl)
continue continue
if not key[0]: #if the version is False
warning("ModuleCache.refresh() Found unversioned key in cache, deleting it.", key_pkl)
info("Erasing broken file", key_pkl)
os.remove(key_pkl)
continue
if key not in self.entry_from_key: if key not in self.entry_from_key:
entry = module_name_from_dir(root) entry = module_name_from_dir(root)
self.entry_from_key[key] = entry self.entry_from_key[key] = entry
...@@ -301,6 +306,8 @@ class ModuleCache(object): ...@@ -301,6 +306,8 @@ class ModuleCache(object):
info("deleting ModuleCache entry", entry) info("deleting ModuleCache entry", entry)
del self.entry_from_key[key] del self.entry_from_key[key]
if key[0]:
#this is a versioned entry, so should have been on disk
self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl')) self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl'))
finally: finally:
...@@ -342,6 +349,7 @@ class ModuleCache(object): ...@@ -342,6 +349,7 @@ class ModuleCache(object):
assert name.startswith(location) assert name.startswith(location)
assert name not in self.module_from_name assert name not in self.module_from_name
assert key not in self.entry_from_key assert key not in self.entry_from_key
if _version: # save they key
key_pkl = os.path.join(location, 'key.pkl') key_pkl = os.path.join(location, 'key.pkl')
key_file = file(key_pkl, 'w') key_file = file(key_pkl, 'w')
try: try:
...@@ -354,13 +362,13 @@ class ModuleCache(object): ...@@ -354,13 +362,13 @@ class ModuleCache(object):
warning("Cache leak due to unpickle-able key", key) warning("Cache leak due to unpickle-able key", key)
key_broken = True key_broken = True
if _version and not key_broken: if not key_broken:
key_from_file = cPickle.load(file(key_pkl)) key_from_file = cPickle.load(file(key_pkl))
if key != key_from_file: if key != key_from_file:
raise Exception("key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops", (key, key_from_file)) raise Exception("key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops", (key, key_from_file))
self.loaded_key_pkl.add(key_pkl)
self.entry_from_key[key] = name self.entry_from_key[key] = name
self.module_from_name[name] = module self.module_from_name[name] = module
self.loaded_key_pkl.add(key_pkl)
self.stats[2] += 1 self.stats[2] += 1
rval = module rval = module
...@@ -448,6 +456,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -448,6 +456,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
include_dirs = [distutils.sysconfig.get_python_inc()] + \ include_dirs = [distutils.sysconfig.get_python_inc()] + \
numpy.distutils.misc_util.get_numpy_include_dirs()\ numpy.distutils.misc_util.get_numpy_include_dirs()\
+ include_dirs + include_dirs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论