提交 c7b8ddfb authored 作者: Michael Osthege's avatar Michael Osthege

Extract local functions and type them

They could get tested independently. This was mostly done to understand their purpose.
上级 5fc74486
......@@ -20,7 +20,7 @@ import tempfile
import textwrap
import time
import warnings
from collections.abc import Callable
from collections.abc import Callable, Collection, Sequence
from contextlib import AbstractContextManager, nullcontext
from io import BytesIO, StringIO
from pathlib import Path
......@@ -2736,35 +2736,12 @@ sure you have the right version you *will* get wrong results.
)
def default_blas_ldflags() -> str:
"""Look for an available BLAS implementation in the system.
This function tries to compile a simple C code that uses the BLAS
if the required files are found in the system.
It sequentially tries to link to the following implementations, until one is found:
1. Intel MKL with Intel OpenMP threading
2. Intel MKL with GNU OpenMP threading
3. Lapack + BLAS
4. BLAS alone
5. OpenBLAS
Returns
-------
blas flags: str
Blas flags needed to link to the BLAS implementation found in the system.
If no BLAS implementation is found, an empty string is returned.
Notes
-----
This function is triggered when `pytensor.config.blas__ldflags` is not given a user
default, and it is first accessed at runtime. It can be rather slow, so it is advised
to cache the results of this function in PYTENSORRC configuration file or
PyTensor environment flags.
"""
def check_required_file(paths, required_regexs):
libs = []
def _check_required_file(
paths: Collection[Path],
required_regexs: Collection[str | re.Pattern[str]],
) -> list[tuple[str, str]]:
"""Select path parents for each required pattern."""
libs: list[tuple[str, str]] = []
for req in required_regexs:
found = False
for path in paths:
......@@ -2778,7 +2755,9 @@ def default_blas_ldflags() -> str:
raise RuntimeError(f"Required file {req} not found")
return libs
def get_cxx_library_dirs():
def _get_cxx_library_dirs() -> list[str]:
"""Query C++ search dirs and return those the existing ones."""
cmd = [config.cxx, "-print-search-dirs"]
p = subprocess_Popen(
cmd,
......@@ -2801,18 +2780,19 @@ def default_blas_ldflags() -> str:
for line in stdout.decode(sys.getdefaultencoding()).splitlines()
if line.startswith("libraries: =")
]
if len(maybe_lib_dirs) > 0:
maybe_lib_dirs = maybe_lib_dirs[0]
return [str(d) for d in maybe_lib_dirs if d.exists() and d.is_dir()]
def check_libs(
all_libs, required_libs, extra_compile_flags=None, cxx_library_dirs=None
) -> str:
if cxx_library_dirs is None:
cxx_library_dirs = []
if extra_compile_flags is None:
extra_compile_flags = []
found_libs = check_required_file(
if not maybe_lib_dirs:
return []
return [str(d) for d in maybe_lib_dirs[0] if d.exists() and d.is_dir()]
def _check_libs(
all_libs: Collection[Path],
required_libs: Collection[str | re.Pattern],
extra_compile_flags: Sequence[str] = (),
cxx_library_dirs: Sequence[str] = (),
) -> str:
"""Assembly library paths and try BLAS flags, returning the flags on success."""
found_libs = _check_required_file(
all_libs,
required_libs,
)
......@@ -2830,10 +2810,13 @@ def default_blas_ldflags() -> str:
flags = (
libdir_ldflags
+ [f"-l{lib_name}" for _, lib_name in found_libs]
+ extra_compile_flags
+ list(extra_compile_flags)
)
res = try_blas_flag(flags)
if res:
if not res:
_logger.debug("Supplied flags '%s' failed to compile", res)
raise RuntimeError(f"Supplied flags {flags} failed to compile")
if any("mkl" in flag for flag in flags):
try:
check_mkl_openmp()
......@@ -2841,9 +2824,34 @@ def default_blas_ldflags() -> str:
_logger.debug(e)
_logger.debug("The following blas flags will be used: '%s'", res)
return res
else:
_logger.debug("Supplied flags '%s' failed to compile", res)
raise RuntimeError(f"Supplied flags {flags} failed to compile")
def default_blas_ldflags() -> str:
"""Look for an available BLAS implementation in the system.
This function tries to compile a simple C code that uses the BLAS
if the required files are found in the system.
It sequentially tries to link to the following implementations, until one is found:
1. Intel MKL with Intel OpenMP threading
2. Intel MKL with GNU OpenMP threading
3. Lapack + BLAS
4. BLAS alone
5. OpenBLAS
Returns
-------
blas flags: str
Blas flags needed to link to the BLAS implementation found in the system.
If no BLAS implementation is found, an empty string is returned.
Notes
-----
This function is triggered when `pytensor.config.blas__ldflags` is not given a user
default, and it is first accessed at runtime. It can be rather slow, so it is advised
to cache the results of this function in PYTENSORRC configuration file or
PyTensor environment flags.
"""
# If no compiler is available we default to empty ldflags
if not config.cxx:
......@@ -2854,7 +2862,7 @@ def default_blas_ldflags() -> str:
else:
rpath = None
cxx_library_dirs = get_cxx_library_dirs()
cxx_library_dirs = _get_cxx_library_dirs()
searched_library_dirs = cxx_library_dirs + _std_lib_dirs
if sys.platform == "win32":
# Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
......@@ -2884,7 +2892,7 @@ def default_blas_ldflags() -> str:
try:
# 1. Try to use MKL with INTEL OpenMP threading
_logger.debug("Checking MKL flags with intel threading")
return check_libs(
return _check_libs(
all_libs,
required_libs=[
"mkl_core",
......@@ -2901,7 +2909,7 @@ def default_blas_ldflags() -> str:
try:
# 2. Try to use MKL with GNU OpenMP threading
_logger.debug("Checking MKL flags with GNU OpenMP threading")
return check_libs(
return _check_libs(
all_libs,
required_libs=["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
......@@ -2924,7 +2932,7 @@ def default_blas_ldflags() -> str:
try:
_logger.debug("Checking Lapack + blas")
# 4. Try to use LAPACK + BLAS
return check_libs(
return _check_libs(
all_libs,
required_libs=["lapack", "blas", "cblas", "m"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
......@@ -2935,7 +2943,7 @@ def default_blas_ldflags() -> str:
try:
# 5. Try to use BLAS alone
_logger.debug("Checking blas alone")
return check_libs(
return _check_libs(
all_libs,
required_libs=["blas", "cblas"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
......@@ -2946,7 +2954,7 @@ def default_blas_ldflags() -> str:
try:
# 6. Try to use openblas
_logger.debug("Checking openblas")
return check_libs(
return _check_libs(
all_libs,
required_libs=["openblas", "gfortran", "gomp", "m"],
extra_compile_flags=["-fopenmp", f"-Wl,-rpath,{rpath}"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论