提交 2bae7b9b authored 作者: abergeron's avatar abergeron

Merge pull request #1741 from nouiz/cache

Fix recompilation of Ops in other repo then Theano
...@@ -975,10 +975,24 @@ class FunctionMaker(object): ...@@ -975,10 +975,24 @@ class FunctionMaker(object):
raise TypeError( raise TypeError(
'profile passed via both "mode" and "profile" arguments') 'profile passed via both "mode" and "profile" arguments')
self.profile = profile = profile or mode_profile self.profile = profile = profile or mode_profile
if profile: if profile or theano.config.cxx:
# We preload the cache here to don't have its timming # This is very important:
# 1) We preload the cache here to don't have its timming
# included in optimization that compile function. # included in optimization that compile function.
theano.gof.cc.get_module_cache() # 2) If other repo that import Theano have Theano ops defined,
# we need to refresh the cache here. Otherwise, their is import
# order problems.
# When device=gpu, we compile during Theano import. This trigger
# the loading of the cache. But unpickling the cache ask that the
# other repos Ops are completly loaded, which isn't always the
# case!
# If a module isn't completly loaded and their unpickling fail,
# it mean it is safe for this function compilation to skip them,
# but not for futur compilation. So reloading the cache at each
# compilation fix this problem.
# 3) This help propagate knowledge of newly compiled module to
# concurrent process.
theano.gof.cc.get_module_cache().refresh()
# Handle the case where inputs and/or outputs is a single Variable (not in a list) # Handle the case where inputs and/or outputs is a single Variable (not in a list)
self.orig_outputs = outputs self.orig_outputs = outputs
unpack_single = False unpack_single = False
......
...@@ -308,11 +308,12 @@ def last_access_time(path): ...@@ -308,11 +308,12 @@ def last_access_time(path):
return os.stat(path)[stat.ST_ATIME] return os.stat(path)[stat.ST_ATIME]
def module_name_from_dir(dirname, err=True): def module_name_from_dir(dirname, err=True, files=None):
""" """
Scan the contents of a cache directory and return full path of the Scan the contents of a cache directory and return full path of the
dynamic lib in it. dynamic lib in it.
""" """
if files is None:
files = os.listdir(dirname) files = os.listdir(dirname)
names = [file for file in files names = [file for file in files
if file.endswith('.so') or file.endswith('.pyd')] if file.endswith('.so') or file.endswith('.pyd')]
...@@ -640,18 +641,21 @@ class ModuleCache(object): ...@@ -640,18 +641,21 @@ class ModuleCache(object):
time_now = time.time() time_now = time.time()
# Go through directories in alphabetical order to ensure consistent # Go through directories in alphabetical order to ensure consistent
# behavior. # behavior.
root_dirs_files = sorted(os.walk(self.dirname), subdirs = sorted(os.listdir(self.dirname))
key=operator.itemgetter(0)) for root in subdirs:
for root, dirs, files in root_dirs_files: root = os.path.join(self.dirname, root)
key_pkl = os.path.join(root, 'key.pkl') key_pkl = os.path.join(root, 'key.pkl')
if key_pkl in self.loaded_key_pkl: if key_pkl in self.loaded_key_pkl:
continue continue
elif 'delete.me' in files or not files: if not os.path.isdir(root):
continue
files = os.listdir(root)
if 'delete.me' in files or not files:
_rmtree(root, ignore_nocleanup=True, _rmtree(root, ignore_nocleanup=True,
msg="delete.me found in dir") msg="delete.me found in dir")
elif 'key.pkl' in files: elif 'key.pkl' in files:
try: try:
entry = module_name_from_dir(root) entry = module_name_from_dir(root, files=files)
except ValueError: # there is a key but no dll! except ValueError: # there is a key but no dll!
if not root.startswith("/tmp"): if not root.startswith("/tmp"):
# Under /tmp, file are removed periodically by the # Under /tmp, file are removed periodically by the
...@@ -814,8 +818,7 @@ class ModuleCache(object): ...@@ -814,8 +818,7 @@ class ModuleCache(object):
# We do nothing here. # We do nothing here.
# Clean up the name space to prevent bug. # Clean up the name space to prevent bug.
if root_dirs_files: del root, files, subdirs
del root, dirs, files
# Remove entries that are not in the filesystem. # Remove entries that are not in the filesystem.
items_copy = list(self.module_hash_to_key_data.iteritems()) items_copy = list(self.module_hash_to_key_data.iteritems())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论