提交 4e5570d6 authored 作者: Ben Mares's avatar Ben Mares 提交者: Ricardo Vieira

Use add_dll_directory as a context manager

上级 5e612abb
...@@ -21,7 +21,7 @@ import textwrap ...@@ -21,7 +21,7 @@ import textwrap
import time import time
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable
from functools import cache from contextlib import AbstractContextManager, nullcontext
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import TYPE_CHECKING, Protocol, cast from typing import TYPE_CHECKING, Protocol, cast
...@@ -272,15 +272,15 @@ def _get_ext_suffix(): ...@@ -272,15 +272,15 @@ def _get_ext_suffix():
return dist_suffix return dist_suffix
@cache # See explanation in docstring. def add_gcc_dll_directory() -> AbstractContextManager[None]:
def add_gcc_dll_directory() -> None:
"""On Windows, detect and add the location of gcc to the DLL search directory. """On Windows, detect and add the location of gcc to the DLL search directory.
On non-Windows platforms this is a noop. On non-Windows platforms this is a noop.
The @cache decorator ensures that this function only executes once to avoid Returns a context manager to be used with `with`. The entry is removed when the
redundant entries. See <https://github.com/pymc-devs/pytensor/pull/678>. context manager is closed. See <https://github.com/pymc-devs/pytensor/pull/678>.
""" """
cm: AbstractContextManager[None] = nullcontext()
if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")): if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")):
gcc_path = shutil.which("gcc") gcc_path = shutil.which("gcc")
if gcc_path is not None: if gcc_path is not None:
...@@ -288,7 +288,8 @@ def add_gcc_dll_directory() -> None: ...@@ -288,7 +288,8 @@ def add_gcc_dll_directory() -> None:
# the ignore[attr-defined] on non-Windows platforms. # the ignore[attr-defined] on non-Windows platforms.
# For Windows we need ignore[unused-ignore] since the ignore # For Windows we need ignore[unused-ignore] since the ignore
# is unnecessary with that platform. # is unnecessary with that platform.
os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore] cm = os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore]
return cm
def dlimport(fullpath, suffix=None): def dlimport(fullpath, suffix=None):
...@@ -340,7 +341,7 @@ def dlimport(fullpath, suffix=None): ...@@ -340,7 +341,7 @@ def dlimport(fullpath, suffix=None):
_logger.debug(f"module_name {module_name}") _logger.debug(f"module_name {module_name}")
sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily) sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
add_gcc_dll_directory() with add_gcc_dll_directory():
global import_time global import_time
try: try:
importlib.invalidate_caches() importlib.invalidate_caches()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论