提交 4e5570d6 authored 作者: Ben Mares's avatar Ben Mares 提交者: Ricardo Vieira

Use add_dll_directory as a context manager

上级 5e612abb
......@@ -21,7 +21,7 @@ import textwrap
import time
import warnings
from collections.abc import Callable
from functools import cache
from contextlib import AbstractContextManager, nullcontext
from io import BytesIO, StringIO
from typing import TYPE_CHECKING, Protocol, cast
......@@ -272,15 +272,15 @@ def _get_ext_suffix():
return dist_suffix
@cache # See explanation in docstring.
def add_gcc_dll_directory() -> None:
def add_gcc_dll_directory() -> AbstractContextManager[None]:
"""On Windows, detect and add the location of gcc to the DLL search directory.
On non-Windows platforms this is a noop.
The @cache decorator ensures that this function only executes once to avoid
redundant entries. See <https://github.com/pymc-devs/pytensor/pull/678>.
Returns a context manager to be used with `with`. The entry is removed when the
context manager is closed. See <https://github.com/pymc-devs/pytensor/pull/678>.
"""
cm: AbstractContextManager[None] = nullcontext()
if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")):
gcc_path = shutil.which("gcc")
if gcc_path is not None:
......@@ -288,7 +288,8 @@ def add_gcc_dll_directory() -> None:
# the ignore[attr-defined] on non-Windows platforms.
# For Windows we need ignore[unused-ignore] since the ignore
# is unnecessary with that platform.
os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore]
cm = os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore]
return cm
def dlimport(fullpath, suffix=None):
......@@ -340,20 +341,20 @@ def dlimport(fullpath, suffix=None):
_logger.debug(f"module_name {module_name}")
sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
add_gcc_dll_directory()
global import_time
try:
importlib.invalidate_caches()
t0 = time.perf_counter()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
rval = __import__(module_name, {}, {}, [module_name])
t1 = time.perf_counter()
import_time += t1 - t0
if not rval:
raise Exception("__import__ failed", fullpath)
finally:
del sys.path[0]
with add_gcc_dll_directory():
global import_time
try:
importlib.invalidate_caches()
t0 = time.perf_counter()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
rval = __import__(module_name, {}, {}, [module_name])
t1 = time.perf_counter()
import_time += t1 - t0
if not rval:
raise Exception("__import__ failed", fullpath)
finally:
del sys.path[0]
assert fullpath.startswith(rval.__file__)
return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论