提交 120b1ee9 authored 作者: James Bergstra's avatar James Bergstra

added some more checks to key pickles in cache mechanism

上级 54b0b796
......@@ -7,7 +7,7 @@ import numpy.distutils #TODO: TensorType should handle this
import compilelock # we will abuse the lockfile mechanism when reading and writing the registry
_logger=logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.WARN)
_logger.setLevel(logging.INFO)
def error(*args):
#sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
......@@ -274,6 +274,11 @@ class ModuleCache(object):
info("Erasing broken file", key_pkl)
os.remove(key_pkl)
continue
if not key[0]: #if the version is False
warning("ModuleCache.refresh() Found unversioned key in cache, deleting it.", key_pkl)
info("Erasing broken file", key_pkl)
os.remove(key_pkl)
continue
if key not in self.entry_from_key:
entry = module_name_from_dir(root)
self.entry_from_key[key] = entry
......@@ -301,7 +306,9 @@ class ModuleCache(object):
info("deleting ModuleCache entry", entry)
del self.entry_from_key[key]
self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl'))
if key[0]:
#this is a versioned entry, so should have been on disk
self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl'))
finally:
compilelock.release_lock()
......@@ -342,25 +349,26 @@ class ModuleCache(object):
assert name.startswith(location)
assert name not in self.module_from_name
assert key not in self.entry_from_key
key_pkl = os.path.join(location, 'key.pkl')
key_file = file(key_pkl, 'w')
try:
cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL)
key_file.close()
key_broken = False
except cPickle.PicklingError:
key_file.close()
os.remove(key_pkl)
warning("Cache leak due to unpickle-able key", key)
key_broken = True
if _version and not key_broken:
key_from_file = cPickle.load(file(key_pkl))
if key != key_from_file:
raise Exception("key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops", (key, key_from_file))
if _version: # save they key
key_pkl = os.path.join(location, 'key.pkl')
key_file = file(key_pkl, 'w')
try:
cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL)
key_file.close()
key_broken = False
except cPickle.PicklingError:
key_file.close()
os.remove(key_pkl)
warning("Cache leak due to unpickle-able key", key)
key_broken = True
if not key_broken:
key_from_file = cPickle.load(file(key_pkl))
if key != key_from_file:
raise Exception("key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops", (key, key_from_file))
self.loaded_key_pkl.add(key_pkl)
self.entry_from_key[key] = name
self.module_from_name[name] = module
self.loaded_key_pkl.add(key_pkl)
self.stats[2] += 1
rval = module
......@@ -448,6 +456,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
preargs.append('-fPIC')
no_opt = False
include_dirs = [distutils.sysconfig.get_python_inc()] + \
numpy.distutils.misc_util.get_numpy_include_dirs()\
+ include_dirs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论