提交 c028e387 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Increase lock scope in ModuleCache.refresh

上级 9ac18dc1
差异被折叠。
...@@ -5,6 +5,7 @@ But this one tests a current behavior that isn't good: the c_code isn't ...@@ -5,6 +5,7 @@ But this one tests a current behavior that isn't good: the c_code isn't
deterministic based on the input type and the op. deterministic based on the input type and the op.
""" """
import logging import logging
import multiprocessing
import os import os
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
...@@ -12,6 +13,8 @@ from unittest.mock import patch ...@@ -12,6 +13,8 @@ from unittest.mock import patch
import numpy as np import numpy as np
import pytest import pytest
import aesara
import aesara.tensor as at
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -139,3 +142,47 @@ def test_linking_patch(listdir_mock, platform): ...@@ -139,3 +142,47 @@ def test_linking_patch(listdir_mock, platform):
"-lmkl_core", "-lmkl_core",
"-lmkl_rt", "-lmkl_rt",
] ]
def test_cache_race_condition():
with tempfile.TemporaryDirectory() as dir_name:
@config.change_flags(on_opt_error="raise", on_shape_error="raise")
def f_build(factor):
# Some of the caching issues arise during constant folding within the
# optimization passes, so we need these config changes to prevent the
# exceptions from being caught
a = at.vector()
f = aesara.function([a], factor * a)
return f(np.array([1], dtype=config.floatX))
ctx = multiprocessing.get_context()
compiledir_prop = aesara.config._config_var_dict["compiledir"]
# The module cache must (initially) be `None` for all processes so that
# `ModuleCache.refresh` is called
with patch.object(compiledir_prop, "val", dir_name, create=True), patch.object(
aesara.link.c.cmodule, "_module_cache", None
):
assert aesara.config.compiledir == dir_name
num_procs = 30
rng = np.random.default_rng(209)
for i in range(10):
# A random, constant input to prevent caching between runs
factor = rng.random()
procs = [
ctx.Process(target=f_build, args=(factor,))
for i in range(num_procs)
]
for proc in procs:
proc.start()
for proc in procs:
proc.join()
assert not any(
exit_code != 0 for exit_code in [proc.exitcode for proc in procs]
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论