提交 c028e387 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Increase lock scope in ModuleCache.refresh

上级 9ac18dc1
......@@ -785,263 +785,269 @@ class ModuleCache:
except OSError:
# This can happen if the dir doesn't exist.
subdirs = []
files, root = None, None # To make sure the "del" below works
for subdirs_elem in subdirs:
# Never clean/remove lock_dir
if subdirs_elem == "lock_dir":
continue
root = os.path.join(self.dirname, subdirs_elem)
key_pkl = os.path.join(root, "key.pkl")
if key_pkl in self.loaded_key_pkl:
continue
if not os.path.isdir(root):
continue
files = os.listdir(root)
if not files:
rmtree_empty(root, ignore_nocleanup=True, msg="empty dir")
continue
if "delete.me" in files:
rmtree(root, ignore_nocleanup=True, msg="delete.me found in dir")
continue
elif "key.pkl" in files:
try:
entry = module_name_from_dir(root, files=files)
except ValueError: # there is a key but no dll!
if not root.startswith("/tmp"):
# Under /tmp, files are removed periodically by the
# OS, so it is normal for this to happen from time
# to time.
_logger.warning(
"ModuleCache.refresh() Found key "
f"without dll in cache, deleting it. {key_pkl}",
)
rmtree(
root,
ignore_nocleanup=True,
msg="missing module file",
level=logging.INFO,
)
continue
if (time_now - last_access_time(entry)) < age_thresh_use:
_logger.debug(f"refresh adding {key_pkl}")
def unpickle_failure():
_logger.info(
f"ModuleCache.refresh() Failed to unpickle cache file {key_pkl}",
)
files, root = None, None # To make sure the "del" below works
# Collections used by external (and potentially asynchronous)
# compilation processes are modified in the following loop, so we need
# to lock on the compilation directory so that those processes don't
# work with stale/invalid data
with lock_ctx():
for subdirs_elem in subdirs:
# Never clean/remove lock_dir
if subdirs_elem == "lock_dir":
continue
root = os.path.join(self.dirname, subdirs_elem)
key_pkl = os.path.join(root, "key.pkl")
if key_pkl in self.loaded_key_pkl:
continue
if not os.path.isdir(root):
continue
files = os.listdir(root)
if not files:
rmtree_empty(root, ignore_nocleanup=True, msg="empty dir")
continue
if "delete.me" in files:
rmtree(root, ignore_nocleanup=True, msg="delete.me found in dir")
continue
elif "key.pkl" in files:
try:
with open(key_pkl, "rb") as f:
key_data = pickle.load(f)
except EOFError:
# Happened once... not sure why (would be worth
# investigating if it ever happens again).
unpickle_failure()
entry = module_name_from_dir(root, files=files)
except ValueError: # there is a key but no dll!
if not root.startswith("/tmp"):
# Under /tmp, files are removed periodically by the
# OS, so it is normal for this to happen from time
# to time.
_logger.warning(
"ModuleCache.refresh() Found key "
f"without dll in cache, deleting it. {key_pkl}",
)
rmtree(
root,
ignore_nocleanup=True,
msg="broken cache directory [EOF]",
level=logging.WARNING,
msg="missing module file",
level=logging.INFO,
)
continue
except Exception:
unpickle_failure()
if delete_if_problem:
if (time_now - last_access_time(entry)) < age_thresh_use:
_logger.debug(f"refresh adding {key_pkl}")
def unpickle_failure():
_logger.info(
f"ModuleCache.refresh() Failed to unpickle cache file {key_pkl}",
)
try:
with open(key_pkl, "rb") as f:
key_data = pickle.load(f)
except EOFError:
# Happened once... not sure why (would be worth
# investigating if it ever happens again).
unpickle_failure()
rmtree(
root,
ignore_nocleanup=True,
msg="broken cache directory",
level=logging.INFO,
msg="broken cache directory [EOF]",
level=logging.WARNING,
)
else:
# This exception is often triggered by keys
# that contain references to classes that have
# not yet been imported (e.g. when running two
# different Aesara-based scripts). They are not
# necessarily broken, but we cannot load them
# now. They will be loaded later if needed.
pass
continue
if not isinstance(key_data, KeyData):
# This is some old cache data, that does not fit
# the new cache format. It would be possible to
# update it, but it is not entirely safe since we
# do not know the config options that were used.
# As a result, we delete it instead (which is also
# simpler to implement).
rmtree(
root,
ignore_nocleanup=True,
msg=(
"invalid cache entry format -- this "
"should not happen unless your cache "
"was really old"
),
level=logging.WARN,
)
continue
continue
except Exception:
unpickle_failure()
if delete_if_problem:
rmtree(
root,
ignore_nocleanup=True,
msg="broken cache directory",
level=logging.INFO,
)
else:
# This exception is often triggered by keys
# that contain references to classes that have
# not yet been imported (e.g. when running two
# different Aesara-based scripts). They are not
# necessarily broken, but we cannot load them
# now. They will be loaded later if needed.
pass
continue
# Check the path to the module stored in the KeyData
# object matches the path to `entry`. There may be
# a mismatch e.g. due to symlinks, or some directory
# being renamed since last time cache was created.
kd_entry = key_data.get_entry()
if kd_entry != entry:
if is_same_entry(entry, kd_entry):
# Update KeyData object. Note that we also need
# to update the key_pkl field, because it is
# likely to be incorrect if the entry itself
# was wrong.
key_data.entry = entry
key_data.key_pkl = key_pkl
else:
# This is suspicious. Better get rid of it.
if not isinstance(key_data, KeyData):
# This is some old cache data, that does not fit
# the new cache format. It would be possible to
# update it, but it is not entirely safe since we
# do not know the config options that were used.
# As a result, we delete it instead (which is also
# simpler to implement).
rmtree(
root,
ignore_nocleanup=True,
msg="module file path mismatch",
level=logging.INFO,
msg=(
"invalid cache entry format -- this "
"should not happen unless your cache "
"was really old"
),
level=logging.WARN,
)
continue
# Find unversioned keys from other processes.
# TODO: check if this can happen at all
to_del = [key for key in key_data.keys if not key[0]]
if to_del:
_logger.warning(
"ModuleCache.refresh() Found unversioned "
f"key in cache, removing it. {key_pkl}",
)
# Since the version is in the module hash, all
# keys should be unversioned.
if len(to_del) != len(key_data.keys):
_logger.warning(
"Found a mix of unversioned and "
"versioned keys for the same "
f"module {key_pkl}",
)
rmtree(
root,
ignore_nocleanup=True,
msg="unversioned key(s) in cache",
level=logging.INFO,
)
continue
mod_hash = key_data.module_hash
if mod_hash in self.module_hash_to_key_data:
# This may happen when two processes running
# simultaneously compiled the same module, one
# after the other. We delete one once it is old
# enough (to be confident there is no other process
# using it), or if `delete_if_problem` is True.
# Note that it is important to walk through
# directories in alphabetical order so as to make
# sure all new processes only use the first one.
if cleanup:
age = time.time() - last_access_time(entry)
if delete_if_problem or age > self.age_thresh_del:
# Check the path to the module stored in the KeyData
# object matches the path to `entry`. There may be
# a mismatch e.g. due to symlinks, or some directory
# being renamed since last time cache was created.
kd_entry = key_data.get_entry()
if kd_entry != entry:
if is_same_entry(entry, kd_entry):
# Update KeyData object. Note that we also need
# to update the key_pkl field, because it is
# likely to be incorrect if the entry itself
# was wrong.
key_data.entry = entry
key_data.key_pkl = key_pkl
else:
# This is suspicious. Better get rid of it.
rmtree(
root,
ignore_nocleanup=True,
msg="duplicated module",
level=logging.DEBUG,
)
else:
_logger.debug(
"Found duplicated module not "
"old enough yet to be deleted "
f"(age: {age}): {entry}",
msg="module file path mismatch",
level=logging.INFO,
)
continue
continue
# Remember the map from a module's hash to the KeyData
# object associated with it.
self.module_hash_to_key_data[mod_hash] = key_data
for key in key_data.keys:
if key not in self.entry_from_key:
self.entry_from_key[key] = entry
# Assert that we have not already got this
# entry somehow.
assert entry not in self.module_from_name
# Store safe part of versioned keys.
if key[0]:
self.similar_keys.setdefault(
get_safe_part(key), []
).append(key)
else:
dir1 = os.path.dirname(self.entry_from_key[key])
dir2 = os.path.dirname(entry)
# Find unversioned keys from other processes.
# TODO: check if this can happen at all
to_del = [key for key in key_data.keys if not key[0]]
if to_del:
_logger.warning(
"The same cache key is associated to "
f"different modules ({dir1} and {dir2}). This "
"is not supposed to happen! You may "
"need to manually delete your cache "
"directory to fix this.",
"ModuleCache.refresh() Found unversioned "
f"key in cache, removing it. {key_pkl}",
)
# Clean up the name space to prevent bug.
if key_data.keys:
del key
self.loaded_key_pkl.add(key_pkl)
else:
too_old_to_use.append(entry)
# Since the version is in the module hash, all
# keys should be unversioned.
if len(to_del) != len(key_data.keys):
_logger.warning(
"Found a mix of unversioned and "
"versioned keys for the same "
f"module {key_pkl}",
)
rmtree(
root,
ignore_nocleanup=True,
msg="unversioned key(s) in cache",
level=logging.INFO,
)
continue
# If the compilation failed, no key.pkl is in that
# directory, but a mod.* should be there.
# We do nothing here.
mod_hash = key_data.module_hash
if mod_hash in self.module_hash_to_key_data:
# This may happen when two processes running
# simultaneously compiled the same module, one
# after the other. We delete one once it is old
# enough (to be confident there is no other process
# using it), or if `delete_if_problem` is True.
# Note that it is important to walk through
# directories in alphabetical order so as to make
# sure all new processes only use the first one.
if cleanup:
age = time.time() - last_access_time(entry)
if delete_if_problem or age > self.age_thresh_del:
rmtree(
root,
ignore_nocleanup=True,
msg="duplicated module",
level=logging.DEBUG,
)
else:
_logger.debug(
"Found duplicated module not "
"old enough yet to be deleted "
f"(age: {age}): {entry}",
)
continue
# Clean up the name space to prevent bug.
del root, files, subdirs
# Remember the map from a module's hash to the KeyData
# object associated with it.
self.module_hash_to_key_data[mod_hash] = key_data
for key in key_data.keys:
if key not in self.entry_from_key:
self.entry_from_key[key] = entry
# Assert that we have not already got this
# entry somehow.
assert entry not in self.module_from_name
# Store safe part of versioned keys.
if key[0]:
self.similar_keys.setdefault(
get_safe_part(key), []
).append(key)
else:
dir1 = os.path.dirname(self.entry_from_key[key])
dir2 = os.path.dirname(entry)
_logger.warning(
"The same cache key is associated to "
f"different modules ({dir1} and {dir2}). This "
"is not supposed to happen! You may "
"need to manually delete your cache "
"directory to fix this.",
)
# Clean up the name space to prevent bug.
if key_data.keys:
del key
self.loaded_key_pkl.add(key_pkl)
else:
too_old_to_use.append(entry)
# Remove entries that are not in the filesystem.
items_copy = list(self.module_hash_to_key_data.items())
for module_hash, key_data in items_copy:
entry = key_data.get_entry()
try:
# Test to see that the file is [present and] readable.
open(entry).close()
gone = False
except OSError:
gone = True
if gone:
# Assert that we did not have one of the deleted files
# loaded up and in use.
# If so, it should not have been deleted. This should be
# considered a failure of the OTHER process, that deleted
# it.
if entry in self.module_from_name:
_logger.warning(
"A module that was loaded by this "
"ModuleCache can no longer be read from file "
f"{entry}... this could lead to problems.",
)
del self.module_from_name[entry]
_logger.info(f"deleting ModuleCache entry {entry}")
key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash]
if key_data.keys and list(key_data.keys)[0][0]:
# this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we
# are responding by printing a warning, removing
# evidence that we ever saw this mystery key.
pkl_file_to_remove = key_data.key_pkl
if not key_data.key_pkl.startswith("/tmp"):
# Under /tmp, files are removed periodically by the
# OS, so it is normal for this to happen from time to
# time.
# If the compilation failed, no key.pkl is in that
# directory, but a mod.* should be there.
# We do nothing here.
# Clean up the name space to prevent bug.
del root, files, subdirs
# Remove entries that are not in the filesystem.
items_copy = list(self.module_hash_to_key_data.items())
for module_hash, key_data in items_copy:
entry = key_data.get_entry()
try:
# Test to see that the file is [present and] readable.
open(entry).close()
gone = False
except OSError:
gone = True
if gone:
# Assert that we did not have one of the deleted files
# loaded up and in use.
# If so, it should not have been deleted. This should be
# considered a failure of the OTHER process, that deleted
# it.
if entry in self.module_from_name:
_logger.warning(
f"Removing key file {pkl_file_to_remove} because the "
"corresponding module is gone from the "
"file system."
"A module that was loaded by this "
"ModuleCache can no longer be read from file "
f"{entry}... this could lead to problems.",
)
self.loaded_key_pkl.remove(pkl_file_to_remove)
del self.module_from_name[entry]
_logger.info(f"deleting ModuleCache entry {entry}")
key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash]
if key_data.keys and list(key_data.keys)[0][0]:
# this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we
# are responding by printing a warning, removing
# evidence that we ever saw this mystery key.
pkl_file_to_remove = key_data.key_pkl
if not key_data.key_pkl.startswith("/tmp"):
# Under /tmp, files are removed periodically by the
# OS, so it is normal for this to happen from time to
# time.
_logger.warning(
f"Removing key file {pkl_file_to_remove} because the "
"corresponding module is gone from the "
"file system."
)
self.loaded_key_pkl.remove(pkl_file_to_remove)
if to_delete or to_delete_empty:
with lock_ctx():
if to_delete or to_delete_empty:
for a, kw in to_delete:
_rmtree(*a, **kw)
for a, kw in to_delete_empty:
......@@ -1049,7 +1055,7 @@ class ModuleCache:
if not files:
_rmtree(*a, **kw)
_logger.debug(f"Time needed to refresh cache: {time.time() - start_time}")
_logger.debug(f"Time needed to refresh cache: {time.time() - start_time}")
return too_old_to_use
......
......@@ -5,6 +5,7 @@ But this one tests a current behavior that isn't good: the c_code isn't
deterministic based on the input type and the op.
"""
import logging
import multiprocessing
import os
import tempfile
from unittest.mock import patch
......@@ -12,6 +13,8 @@ from unittest.mock import patch
import numpy as np
import pytest
import aesara
import aesara.tensor as at
from aesara.compile.function import function
from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config
......@@ -139,3 +142,47 @@ def test_linking_patch(listdir_mock, platform):
"-lmkl_core",
"-lmkl_rt",
]
def test_cache_race_condition():
    """Compile concurrently from many processes that share one compile directory.

    For 10 rounds, 30 processes are started at once; each builds and calls
    an Aesara function that multiplies a vector by a per-round random
    constant, using a temporary directory as the compile directory. The
    test passes only if every process exits with code 0.
    """
    with tempfile.TemporaryDirectory() as dir_name:

        # Some of the caching issues arise during constant folding within the
        # optimization passes, so these flag changes are needed to prevent the
        # resulting exceptions from being caught.
        @config.change_flags(on_opt_error="raise", on_shape_error="raise")
        def f_build(factor):
            vec = at.vector()
            scaled_fn = aesara.function([vec], factor * vec)
            return scaled_fn(np.array([1], dtype=config.floatX))

        mp_ctx = multiprocessing.get_context()
        compiledir_var = aesara.config._config_var_dict["compiledir"]

        # The module cache must (initially) be `None` for all processes so
        # that `ModuleCache.refresh` is called.
        with patch.object(compiledir_var, "val", dir_name, create=True):
            with patch.object(aesara.link.c.cmodule, "_module_cache", None):

                assert aesara.config.compiledir == dir_name

                worker_count = 30
                rng = np.random.default_rng(209)

                for _ in range(10):
                    # A random, constant input to prevent caching between runs
                    factor = rng.random()

                    workers = []
                    for _ in range(worker_count):
                        workers.append(
                            mp_ctx.Process(target=f_build, args=(factor,))
                        )

                    # Start every worker before waiting on any of them so
                    # that they all run concurrently.
                    for worker in workers:
                        worker.start()
                    for worker in workers:
                        worker.join()

                    exit_codes = [worker.exitcode for worker in workers]
                    assert all(code == 0 for code in exit_codes)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论