提交 c028e387 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Increase lock scope in ModuleCache.refresh

上级 9ac18dc1
...@@ -785,263 +785,269 @@ class ModuleCache: ...@@ -785,263 +785,269 @@ class ModuleCache:
except OSError: except OSError:
# This can happen if the dir don't exist. # This can happen if the dir don't exist.
subdirs = [] subdirs = []
files, root = None, None # To make sure the "del" below works
for subdirs_elem in subdirs:
# Never clean/remove lock_dir
if subdirs_elem == "lock_dir":
continue
root = os.path.join(self.dirname, subdirs_elem)
key_pkl = os.path.join(root, "key.pkl")
if key_pkl in self.loaded_key_pkl:
continue
if not os.path.isdir(root):
continue
files = os.listdir(root)
if not files:
rmtree_empty(root, ignore_nocleanup=True, msg="empty dir")
continue
if "delete.me" in files:
rmtree(root, ignore_nocleanup=True, msg="delete.me found in dir")
continue
elif "key.pkl" in files:
try:
entry = module_name_from_dir(root, files=files)
except ValueError: # there is a key but no dll!
if not root.startswith("/tmp"):
# Under /tmp, file are removed periodically by the
# os. So it is normal that this happens from time
# to time.
_logger.warning(
"ModuleCache.refresh() Found key "
f"without dll in cache, deleting it. {key_pkl}",
)
rmtree(
root,
ignore_nocleanup=True,
msg="missing module file",
level=logging.INFO,
)
continue
if (time_now - last_access_time(entry)) < age_thresh_use:
_logger.debug(f"refresh adding {key_pkl}")
def unpickle_failure(): files, root = None, None # To make sure the "del" below works
_logger.info(
f"ModuleCache.refresh() Failed to unpickle cache file {key_pkl}",
)
# Collections used by external (and potentially asynchronous)
# compilation processes are modified in the following loop, so we need
# to lock on the compilation directory so that those processes don't
# work with stale/invalid data
with lock_ctx():
for subdirs_elem in subdirs:
# Never clean/remove lock_dir
if subdirs_elem == "lock_dir":
continue
root = os.path.join(self.dirname, subdirs_elem)
key_pkl = os.path.join(root, "key.pkl")
if key_pkl in self.loaded_key_pkl:
continue
if not os.path.isdir(root):
continue
files = os.listdir(root)
if not files:
rmtree_empty(root, ignore_nocleanup=True, msg="empty dir")
continue
if "delete.me" in files:
rmtree(root, ignore_nocleanup=True, msg="delete.me found in dir")
continue
elif "key.pkl" in files:
try: try:
with open(key_pkl, "rb") as f: entry = module_name_from_dir(root, files=files)
key_data = pickle.load(f) except ValueError: # there is a key but no dll!
except EOFError: if not root.startswith("/tmp"):
# Happened once... not sure why (would be worth # Under /tmp, file are removed periodically by the
# investigating if it ever happens again). # os. So it is normal that this happens from time
unpickle_failure() # to time.
_logger.warning(
"ModuleCache.refresh() Found key "
f"without dll in cache, deleting it. {key_pkl}",
)
rmtree( rmtree(
root, root,
ignore_nocleanup=True, ignore_nocleanup=True,
msg="broken cache directory [EOF]", msg="missing module file",
level=logging.WARNING, level=logging.INFO,
) )
continue continue
except Exception: if (time_now - last_access_time(entry)) < age_thresh_use:
unpickle_failure() _logger.debug(f"refresh adding {key_pkl}")
if delete_if_problem:
def unpickle_failure():
_logger.info(
f"ModuleCache.refresh() Failed to unpickle cache file {key_pkl}",
)
try:
with open(key_pkl, "rb") as f:
key_data = pickle.load(f)
except EOFError:
# Happened once... not sure why (would be worth
# investigating if it ever happens again).
unpickle_failure()
rmtree( rmtree(
root, root,
ignore_nocleanup=True, ignore_nocleanup=True,
msg="broken cache directory", msg="broken cache directory [EOF]",
level=logging.INFO, level=logging.WARNING,
) )
else: continue
# This exception is often triggered by keys except Exception:
# that contain references to classes that have unpickle_failure()
# not yet been imported (e.g. when running two if delete_if_problem:
# different Aesara-based scripts). They are not rmtree(
# necessarily broken, but we cannot load them root,
# now. They will be loaded later if needed. ignore_nocleanup=True,
pass msg="broken cache directory",
continue level=logging.INFO,
)
if not isinstance(key_data, KeyData): else:
# This is some old cache data, that does not fit # This exception is often triggered by keys
# the new cache format. It would be possible to # that contain references to classes that have
# update it, but it is not entirely safe since we # not yet been imported (e.g. when running two
# do not know the config options that were used. # different Aesara-based scripts). They are not
# As a result, we delete it instead (which is also # necessarily broken, but we cannot load them
# simpler to implement). # now. They will be loaded later if needed.
rmtree( pass
root, continue
ignore_nocleanup=True,
msg=(
"invalid cache entry format -- this "
"should not happen unless your cache "
"was really old"
),
level=logging.WARN,
)
continue
# Check the path to the module stored in the KeyData if not isinstance(key_data, KeyData):
# object matches the path to `entry`. There may be # This is some old cache data, that does not fit
# a mismatch e.g. due to symlinks, or some directory # the new cache format. It would be possible to
# being renamed since last time cache was created. # update it, but it is not entirely safe since we
kd_entry = key_data.get_entry() # do not know the config options that were used.
if kd_entry != entry: # As a result, we delete it instead (which is also
if is_same_entry(entry, kd_entry): # simpler to implement).
# Update KeyData object. Note that we also need
# to update the key_pkl field, because it is
# likely to be incorrect if the entry itself
# was wrong.
key_data.entry = entry
key_data.key_pkl = key_pkl
else:
# This is suspicious. Better get rid of it.
rmtree( rmtree(
root, root,
ignore_nocleanup=True, ignore_nocleanup=True,
msg="module file path mismatch", msg=(
level=logging.INFO, "invalid cache entry format -- this "
"should not happen unless your cache "
"was really old"
),
level=logging.WARN,
) )
continue continue
# Find unversioned keys from other processes. # Check the path to the module stored in the KeyData
# TODO: check if this can happen at all # object matches the path to `entry`. There may be
to_del = [key for key in key_data.keys if not key[0]] # a mismatch e.g. due to symlinks, or some directory
if to_del: # being renamed since last time cache was created.
_logger.warning( kd_entry = key_data.get_entry()
"ModuleCache.refresh() Found unversioned " if kd_entry != entry:
f"key in cache, removing it. {key_pkl}", if is_same_entry(entry, kd_entry):
) # Update KeyData object. Note that we also need
# Since the version is in the module hash, all # to update the key_pkl field, because it is
# keys should be unversioned. # likely to be incorrect if the entry itself
if len(to_del) != len(key_data.keys): # was wrong.
_logger.warning( key_data.entry = entry
"Found a mix of unversioned and " key_data.key_pkl = key_pkl
"versioned keys for the same " else:
f"module {key_pkl}", # This is suspicious. Better get rid of it.
)
rmtree(
root,
ignore_nocleanup=True,
msg="unversioned key(s) in cache",
level=logging.INFO,
)
continue
mod_hash = key_data.module_hash
if mod_hash in self.module_hash_to_key_data:
# This may happen when two processes running
# simultaneously compiled the same module, one
# after the other. We delete one once it is old
# enough (to be confident there is no other process
# using it), or if `delete_if_problem` is True.
# Note that it is important to walk through
# directories in alphabetical order so as to make
# sure all new processes only use the first one.
if cleanup:
age = time.time() - last_access_time(entry)
if delete_if_problem or age > self.age_thresh_del:
rmtree( rmtree(
root, root,
ignore_nocleanup=True, ignore_nocleanup=True,
msg="duplicated module", msg="module file path mismatch",
level=logging.DEBUG, level=logging.INFO,
)
else:
_logger.debug(
"Found duplicated module not "
"old enough yet to be deleted "
f"(age: {age}): {entry}",
) )
continue continue
# Remember the map from a module's hash to the KeyData # Find unversioned keys from other processes.
# object associated with it. # TODO: check if this can happen at all
self.module_hash_to_key_data[mod_hash] = key_data to_del = [key for key in key_data.keys if not key[0]]
if to_del:
for key in key_data.keys:
if key not in self.entry_from_key:
self.entry_from_key[key] = entry
# Assert that we have not already got this
# entry somehow.
assert entry not in self.module_from_name
# Store safe part of versioned keys.
if key[0]:
self.similar_keys.setdefault(
get_safe_part(key), []
).append(key)
else:
dir1 = os.path.dirname(self.entry_from_key[key])
dir2 = os.path.dirname(entry)
_logger.warning( _logger.warning(
"The same cache key is associated to " "ModuleCache.refresh() Found unversioned "
f"different modules ({dir1} and {dir2}). This " f"key in cache, removing it. {key_pkl}",
"is not supposed to happen! You may "
"need to manually delete your cache "
"directory to fix this.",
) )
# Clean up the name space to prevent bug. # Since the version is in the module hash, all
if key_data.keys: # keys should be unversioned.
del key if len(to_del) != len(key_data.keys):
self.loaded_key_pkl.add(key_pkl) _logger.warning(
else: "Found a mix of unversioned and "
too_old_to_use.append(entry) "versioned keys for the same "
f"module {key_pkl}",
)
rmtree(
root,
ignore_nocleanup=True,
msg="unversioned key(s) in cache",
level=logging.INFO,
)
continue
# If the compilation failed, no key.pkl is in that mod_hash = key_data.module_hash
# directory, but a mod.* should be there. if mod_hash in self.module_hash_to_key_data:
# We do nothing here. # This may happen when two processes running
# simultaneously compiled the same module, one
# after the other. We delete one once it is old
# enough (to be confident there is no other process
# using it), or if `delete_if_problem` is True.
# Note that it is important to walk through
# directories in alphabetical order so as to make
# sure all new processes only use the first one.
if cleanup:
age = time.time() - last_access_time(entry)
if delete_if_problem or age > self.age_thresh_del:
rmtree(
root,
ignore_nocleanup=True,
msg="duplicated module",
level=logging.DEBUG,
)
else:
_logger.debug(
"Found duplicated module not "
"old enough yet to be deleted "
f"(age: {age}): {entry}",
)
continue
# Clean up the name space to prevent bug. # Remember the map from a module's hash to the KeyData
del root, files, subdirs # object associated with it.
self.module_hash_to_key_data[mod_hash] = key_data
for key in key_data.keys:
if key not in self.entry_from_key:
self.entry_from_key[key] = entry
# Assert that we have not already got this
# entry somehow.
assert entry not in self.module_from_name
# Store safe part of versioned keys.
if key[0]:
self.similar_keys.setdefault(
get_safe_part(key), []
).append(key)
else:
dir1 = os.path.dirname(self.entry_from_key[key])
dir2 = os.path.dirname(entry)
_logger.warning(
"The same cache key is associated to "
f"different modules ({dir1} and {dir2}). This "
"is not supposed to happen! You may "
"need to manually delete your cache "
"directory to fix this.",
)
# Clean up the name space to prevent bug.
if key_data.keys:
del key
self.loaded_key_pkl.add(key_pkl)
else:
too_old_to_use.append(entry)
# Remove entries that are not in the filesystem. # If the compilation failed, no key.pkl is in that
items_copy = list(self.module_hash_to_key_data.items()) # directory, but a mod.* should be there.
for module_hash, key_data in items_copy: # We do nothing here.
entry = key_data.get_entry()
try: # Clean up the name space to prevent bug.
# Test to see that the file is [present and] readable. del root, files, subdirs
open(entry).close()
gone = False # Remove entries that are not in the filesystem.
except OSError: items_copy = list(self.module_hash_to_key_data.items())
gone = True for module_hash, key_data in items_copy:
entry = key_data.get_entry()
if gone: try:
# Assert that we did not have one of the deleted files # Test to see that the file is [present and] readable.
# loaded up and in use. open(entry).close()
# If so, it should not have been deleted. This should be gone = False
# considered a failure of the OTHER process, that deleted except OSError:
# it. gone = True
if entry in self.module_from_name:
_logger.warning( if gone:
"A module that was loaded by this " # Assert that we did not have one of the deleted files
"ModuleCache can no longer be read from file " # loaded up and in use.
f"{entry}... this could lead to problems.", # If so, it should not have been deleted. This should be
) # considered a failure of the OTHER process, that deleted
del self.module_from_name[entry] # it.
if entry in self.module_from_name:
_logger.info(f"deleting ModuleCache entry {entry}")
key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash]
if key_data.keys and list(key_data.keys)[0][0]:
# this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we
# are responding by printing a warning, removing
# evidence that we ever saw this mystery key.
pkl_file_to_remove = key_data.key_pkl
if not key_data.key_pkl.startswith("/tmp"):
# Under /tmp, file are removed periodically by the
# os. So it is normal that this happen from time to
# time.
_logger.warning( _logger.warning(
f"Removing key file {pkl_file_to_remove} because the " "A module that was loaded by this "
"corresponding module is gone from the " "ModuleCache can no longer be read from file "
"file system." f"{entry}... this could lead to problems.",
) )
self.loaded_key_pkl.remove(pkl_file_to_remove) del self.module_from_name[entry]
_logger.info(f"deleting ModuleCache entry {entry}")
key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash]
if key_data.keys and list(key_data.keys)[0][0]:
# this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we
# are responding by printing a warning, removing
# evidence that we ever saw this mystery key.
pkl_file_to_remove = key_data.key_pkl
if not key_data.key_pkl.startswith("/tmp"):
# Under /tmp, file are removed periodically by the
# os. So it is normal that this happen from time to
# time.
_logger.warning(
f"Removing key file {pkl_file_to_remove} because the "
"corresponding module is gone from the "
"file system."
)
self.loaded_key_pkl.remove(pkl_file_to_remove)
if to_delete or to_delete_empty: if to_delete or to_delete_empty:
with lock_ctx():
for a, kw in to_delete: for a, kw in to_delete:
_rmtree(*a, **kw) _rmtree(*a, **kw)
for a, kw in to_delete_empty: for a, kw in to_delete_empty:
...@@ -1049,7 +1055,7 @@ class ModuleCache: ...@@ -1049,7 +1055,7 @@ class ModuleCache:
if not files: if not files:
_rmtree(*a, **kw) _rmtree(*a, **kw)
_logger.debug(f"Time needed to refresh cache: {time.time() - start_time}") _logger.debug(f"Time needed to refresh cache: {time.time() - start_time}")
return too_old_to_use return too_old_to_use
......
...@@ -5,6 +5,7 @@ But this one tests a current behavior that isn't good: the c_code isn't ...@@ -5,6 +5,7 @@ But this one tests a current behavior that isn't good: the c_code isn't
deterministic based on the input type and the op. deterministic based on the input type and the op.
""" """
import logging import logging
import multiprocessing
import os import os
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
...@@ -12,6 +13,8 @@ from unittest.mock import patch ...@@ -12,6 +13,8 @@ from unittest.mock import patch
import numpy as np import numpy as np
import pytest import pytest
import aesara
import aesara.tensor as at
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -139,3 +142,47 @@ def test_linking_patch(listdir_mock, platform): ...@@ -139,3 +142,47 @@ def test_linking_patch(listdir_mock, platform):
"-lmkl_core", "-lmkl_core",
"-lmkl_rt", "-lmkl_rt",
] ]
def test_cache_race_condition():
    """Stress-test concurrent access to the compilation cache.

    Spawns many processes that all compile (and run) the same small Aesara
    function against a shared, freshly-created compile directory, so that
    multiple processes hit ``ModuleCache.refresh`` on the same directory at
    once.  The test passes when every worker process exits with code 0,
    i.e. no process crashed on a stale/partially-deleted cache entry.
    """
    with tempfile.TemporaryDirectory() as dir_name:

        @config.change_flags(on_opt_error="raise", on_shape_error="raise")
        def f_build(factor):
            # Some of the caching issues arise during constant folding within the
            # optimization passes, so we need these config changes to prevent the
            # exceptions from being caught
            a = at.vector()
            f = aesara.function([a], factor * a)
            return f(np.array([1], dtype=config.floatX))

        # Default start method for this platform (fork on Linux).
        # NOTE(review): with a "spawn" context the locally-defined `f_build`
        # would not be picklable — this presumably relies on fork; confirm.
        ctx = multiprocessing.get_context()
        compiledir_prop = aesara.config._config_var_dict["compiledir"]

        # The module cache must (initially) be `None` for all processes so that
        # `ModuleCache.refresh` is called
        with patch.object(compiledir_prop, "val", dir_name, create=True), patch.object(
            aesara.link.c.cmodule, "_module_cache", None
        ):
            assert aesara.config.compiledir == dir_name

            num_procs = 30
            rng = np.random.default_rng(209)

            for i in range(10):
                # A random, constant input to prevent caching between runs
                factor = rng.random()
                procs = [
                    ctx.Process(target=f_build, args=(factor,))
                    for i in range(num_procs)
                ]
                for proc in procs:
                    proc.start()

                for proc in procs:
                    proc.join()

                # Every worker must have exited cleanly; a nonzero exit code
                # indicates a crash caused by a cache race.
                assert not any(
                    exit_code != 0 for exit_code in [proc.exitcode for proc in procs]
                )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论