提交 c7b8ddfb authored 作者: Michael Osthege's avatar Michael Osthege

Extract local functions and type them

They could get tested independently. This was mostly done to understand their purpose.
上级 5fc74486
...@@ -20,7 +20,7 @@ import tempfile ...@@ -20,7 +20,7 @@ import tempfile
import textwrap import textwrap
import time import time
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable, Collection, Sequence
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from io import BytesIO, StringIO from io import BytesIO, StringIO
from pathlib import Path from pathlib import Path
...@@ -2736,6 +2736,96 @@ sure you have the right version you *will* get wrong results. ...@@ -2736,6 +2736,96 @@ sure you have the right version you *will* get wrong results.
) )
def _check_required_file(
paths: Collection[Path],
required_regexs: Collection[str | re.Pattern[str]],
) -> list[tuple[str, str]]:
"""Select path parents for each required pattern."""
libs: list[tuple[str, str]] = []
for req in required_regexs:
found = False
for path in paths:
m = re.search(req, path.name)
if m:
libs.append((str(path.parent), m.string[slice(*m.span())]))
found = True
break
if not found:
_logger.debug("Required file '%s' not found", req)
raise RuntimeError(f"Required file {req} not found")
return libs
def _get_cxx_library_dirs() -> list[str]:
"""Query C++ search dirs and return those the existing ones."""
cmd = [config.cxx, "-print-search-dirs"]
p = subprocess_Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.PIPE,
)
(stdout, stderr) = p.communicate(input=b"")
if p.returncode != 0:
warnings.warn(
"Pytensor cxx failed to communicate its search dirs. As a consequence, "
"it might not be possible to automatically determine the blas link flags to use.\n"
f"Command that was run: {config.cxx} -print-search-dirs\n"
f"Output printed to stderr: {stderr.decode(sys.stderr.encoding)}"
)
return []
maybe_lib_dirs = [
[Path(p).resolve() for p in line[len("libraries: =") :].split(":")]
for line in stdout.decode(sys.getdefaultencoding()).splitlines()
if line.startswith("libraries: =")
]
if not maybe_lib_dirs:
return []
return [str(d) for d in maybe_lib_dirs[0] if d.exists() and d.is_dir()]
def _check_libs(
all_libs: Collection[Path],
required_libs: Collection[str | re.Pattern],
extra_compile_flags: Sequence[str] = (),
cxx_library_dirs: Sequence[str] = (),
) -> str:
"""Assembly library paths and try BLAS flags, returning the flags on success."""
found_libs = _check_required_file(
all_libs,
required_libs,
)
path_quote = '"' if sys.platform == "win32" else ""
libdir_ldflags = list(
dict.fromkeys(
[
f"-L{path_quote}{lib_path}{path_quote}"
for lib_path, _ in found_libs
if lib_path not in cxx_library_dirs
]
)
)
flags = (
libdir_ldflags
+ [f"-l{lib_name}" for _, lib_name in found_libs]
+ list(extra_compile_flags)
)
res = try_blas_flag(flags)
if not res:
_logger.debug("Supplied flags '%s' failed to compile", res)
raise RuntimeError(f"Supplied flags {flags} failed to compile")
if any("mkl" in flag for flag in flags):
try:
check_mkl_openmp()
except Exception as e:
_logger.debug(e)
_logger.debug("The following blas flags will be used: '%s'", res)
return res
def default_blas_ldflags() -> str: def default_blas_ldflags() -> str:
"""Look for an available BLAS implementation in the system. """Look for an available BLAS implementation in the system.
...@@ -2763,88 +2853,6 @@ def default_blas_ldflags() -> str: ...@@ -2763,88 +2853,6 @@ def default_blas_ldflags() -> str:
""" """
def check_required_file(paths, required_regexs):
libs = []
for req in required_regexs:
found = False
for path in paths:
m = re.search(req, path.name)
if m:
libs.append((str(path.parent), m.string[slice(*m.span())]))
found = True
break
if not found:
_logger.debug("Required file '%s' not found", req)
raise RuntimeError(f"Required file {req} not found")
return libs
def get_cxx_library_dirs():
cmd = [config.cxx, "-print-search-dirs"]
p = subprocess_Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.PIPE,
)
(stdout, stderr) = p.communicate(input=b"")
if p.returncode != 0:
warnings.warn(
"Pytensor cxx failed to communicate its search dirs. As a consequence, "
"it might not be possible to automatically determine the blas link flags to use.\n"
f"Command that was run: {config.cxx} -print-search-dirs\n"
f"Output printed to stderr: {stderr.decode(sys.stderr.encoding)}"
)
return []
maybe_lib_dirs = [
[Path(p).resolve() for p in line[len("libraries: =") :].split(":")]
for line in stdout.decode(sys.getdefaultencoding()).splitlines()
if line.startswith("libraries: =")
]
if len(maybe_lib_dirs) > 0:
maybe_lib_dirs = maybe_lib_dirs[0]
return [str(d) for d in maybe_lib_dirs if d.exists() and d.is_dir()]
def check_libs(
all_libs, required_libs, extra_compile_flags=None, cxx_library_dirs=None
) -> str:
if cxx_library_dirs is None:
cxx_library_dirs = []
if extra_compile_flags is None:
extra_compile_flags = []
found_libs = check_required_file(
all_libs,
required_libs,
)
path_quote = '"' if sys.platform == "win32" else ""
libdir_ldflags = list(
dict.fromkeys(
[
f"-L{path_quote}{lib_path}{path_quote}"
for lib_path, _ in found_libs
if lib_path not in cxx_library_dirs
]
)
)
flags = (
libdir_ldflags
+ [f"-l{lib_name}" for _, lib_name in found_libs]
+ extra_compile_flags
)
res = try_blas_flag(flags)
if res:
if any("mkl" in flag for flag in flags):
try:
check_mkl_openmp()
except Exception as e:
_logger.debug(e)
_logger.debug("The following blas flags will be used: '%s'", res)
return res
else:
_logger.debug("Supplied flags '%s' failed to compile", res)
raise RuntimeError(f"Supplied flags {flags} failed to compile")
# If no compiler is available we default to empty ldflags # If no compiler is available we default to empty ldflags
if not config.cxx: if not config.cxx:
return "" return ""
...@@ -2854,7 +2862,7 @@ def default_blas_ldflags() -> str: ...@@ -2854,7 +2862,7 @@ def default_blas_ldflags() -> str:
else: else:
rpath = None rpath = None
cxx_library_dirs = get_cxx_library_dirs() cxx_library_dirs = _get_cxx_library_dirs()
searched_library_dirs = cxx_library_dirs + _std_lib_dirs searched_library_dirs = cxx_library_dirs + _std_lib_dirs
if sys.platform == "win32": if sys.platform == "win32":
# Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin # Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
...@@ -2884,7 +2892,7 @@ def default_blas_ldflags() -> str: ...@@ -2884,7 +2892,7 @@ def default_blas_ldflags() -> str:
try: try:
# 1. Try to use MKL with INTEL OpenMP threading # 1. Try to use MKL with INTEL OpenMP threading
_logger.debug("Checking MKL flags with intel threading") _logger.debug("Checking MKL flags with intel threading")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=[ required_libs=[
"mkl_core", "mkl_core",
...@@ -2901,7 +2909,7 @@ def default_blas_ldflags() -> str: ...@@ -2901,7 +2909,7 @@ def default_blas_ldflags() -> str:
try: try:
# 2. Try to use MKL with GNU OpenMP threading # 2. Try to use MKL with GNU OpenMP threading
_logger.debug("Checking MKL flags with GNU OpenMP threading") _logger.debug("Checking MKL flags with GNU OpenMP threading")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"], required_libs=["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [], extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
...@@ -2924,7 +2932,7 @@ def default_blas_ldflags() -> str: ...@@ -2924,7 +2932,7 @@ def default_blas_ldflags() -> str:
try: try:
_logger.debug("Checking Lapack + blas") _logger.debug("Checking Lapack + blas")
# 4. Try to use LAPACK + BLAS # 4. Try to use LAPACK + BLAS
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["lapack", "blas", "cblas", "m"], required_libs=["lapack", "blas", "cblas", "m"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [], extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
...@@ -2935,7 +2943,7 @@ def default_blas_ldflags() -> str: ...@@ -2935,7 +2943,7 @@ def default_blas_ldflags() -> str:
try: try:
# 5. Try to use BLAS alone # 5. Try to use BLAS alone
_logger.debug("Checking blas alone") _logger.debug("Checking blas alone")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["blas", "cblas"], required_libs=["blas", "cblas"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [], extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
...@@ -2946,7 +2954,7 @@ def default_blas_ldflags() -> str: ...@@ -2946,7 +2954,7 @@ def default_blas_ldflags() -> str:
try: try:
# 6. Try to use openblas # 6. Try to use openblas
_logger.debug("Checking openblas") _logger.debug("Checking openblas")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["openblas", "gfortran", "gomp", "m"], required_libs=["openblas", "gfortran", "gomp", "m"],
extra_compile_flags=["-fopenmp", f"-Wl,-rpath,{rpath}"] extra_compile_flags=["-fopenmp", f"-Wl,-rpath,{rpath}"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论