提交 c7b8ddfb authored 作者: Michael Osthege's avatar Michael Osthege

Extract local functions and type them

They could get tested independently. This was mostly done to understand their purpose.
上级 5fc74486
...@@ -20,7 +20,7 @@ import tempfile ...@@ -20,7 +20,7 @@ import tempfile
import textwrap import textwrap
import time import time
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable, Collection, Sequence
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from io import BytesIO, StringIO from io import BytesIO, StringIO
from pathlib import Path from pathlib import Path
...@@ -2736,35 +2736,12 @@ sure you have the right version you *will* get wrong results. ...@@ -2736,35 +2736,12 @@ sure you have the right version you *will* get wrong results.
) )
def default_blas_ldflags() -> str: def _check_required_file(
"""Look for an available BLAS implementation in the system. paths: Collection[Path],
required_regexs: Collection[str | re.Pattern[str]],
This function tries to compile a simple C code that uses the BLAS ) -> list[tuple[str, str]]:
if the required files are found in the system. """Select path parents for each required pattern."""
It sequentially tries to link to the following implementations, until one is found: libs: list[tuple[str, str]] = []
1. Intel MKL with Intel OpenMP threading
2. Intel MKL with GNU OpenMP threading
3. Lapack + BLAS
4. BLAS alone
5. OpenBLAS
Returns
-------
blas flags: str
Blas flags needed to link to the BLAS implementation found in the system.
If no BLAS implementation is found, an empty string is returned.
Notes
-----
This function is triggered when `pytensor.config.blas__ldflags` is not given a user
default, and it is first accessed at runtime. It can be rather slow, so it is advised
to cache the results of this function in PYTENSORRC configuration file or
PyTensor environment flags.
"""
def check_required_file(paths, required_regexs):
libs = []
for req in required_regexs: for req in required_regexs:
found = False found = False
for path in paths: for path in paths:
...@@ -2778,7 +2755,9 @@ def default_blas_ldflags() -> str: ...@@ -2778,7 +2755,9 @@ def default_blas_ldflags() -> str:
raise RuntimeError(f"Required file {req} not found") raise RuntimeError(f"Required file {req} not found")
return libs return libs
def get_cxx_library_dirs():
def _get_cxx_library_dirs() -> list[str]:
"""Query C++ search dirs and return those the existing ones."""
cmd = [config.cxx, "-print-search-dirs"] cmd = [config.cxx, "-print-search-dirs"]
p = subprocess_Popen( p = subprocess_Popen(
cmd, cmd,
...@@ -2801,18 +2780,19 @@ def default_blas_ldflags() -> str: ...@@ -2801,18 +2780,19 @@ def default_blas_ldflags() -> str:
for line in stdout.decode(sys.getdefaultencoding()).splitlines() for line in stdout.decode(sys.getdefaultencoding()).splitlines()
if line.startswith("libraries: =") if line.startswith("libraries: =")
] ]
if len(maybe_lib_dirs) > 0: if not maybe_lib_dirs:
maybe_lib_dirs = maybe_lib_dirs[0] return []
return [str(d) for d in maybe_lib_dirs if d.exists() and d.is_dir()] return [str(d) for d in maybe_lib_dirs[0] if d.exists() and d.is_dir()]
def check_libs(
all_libs, required_libs, extra_compile_flags=None, cxx_library_dirs=None def _check_libs(
) -> str: all_libs: Collection[Path],
if cxx_library_dirs is None: required_libs: Collection[str | re.Pattern],
cxx_library_dirs = [] extra_compile_flags: Sequence[str] = (),
if extra_compile_flags is None: cxx_library_dirs: Sequence[str] = (),
extra_compile_flags = [] ) -> str:
found_libs = check_required_file( """Assembly library paths and try BLAS flags, returning the flags on success."""
found_libs = _check_required_file(
all_libs, all_libs,
required_libs, required_libs,
) )
...@@ -2830,10 +2810,13 @@ def default_blas_ldflags() -> str: ...@@ -2830,10 +2810,13 @@ def default_blas_ldflags() -> str:
flags = ( flags = (
libdir_ldflags libdir_ldflags
+ [f"-l{lib_name}" for _, lib_name in found_libs] + [f"-l{lib_name}" for _, lib_name in found_libs]
+ extra_compile_flags + list(extra_compile_flags)
) )
res = try_blas_flag(flags) res = try_blas_flag(flags)
if res: if not res:
_logger.debug("Supplied flags '%s' failed to compile", res)
raise RuntimeError(f"Supplied flags {flags} failed to compile")
if any("mkl" in flag for flag in flags): if any("mkl" in flag for flag in flags):
try: try:
check_mkl_openmp() check_mkl_openmp()
...@@ -2841,9 +2824,34 @@ def default_blas_ldflags() -> str: ...@@ -2841,9 +2824,34 @@ def default_blas_ldflags() -> str:
_logger.debug(e) _logger.debug(e)
_logger.debug("The following blas flags will be used: '%s'", res) _logger.debug("The following blas flags will be used: '%s'", res)
return res return res
else:
_logger.debug("Supplied flags '%s' failed to compile", res)
raise RuntimeError(f"Supplied flags {flags} failed to compile") def default_blas_ldflags() -> str:
"""Look for an available BLAS implementation in the system.
This function tries to compile a simple C code that uses the BLAS
if the required files are found in the system.
It sequentially tries to link to the following implementations, until one is found:
1. Intel MKL with Intel OpenMP threading
2. Intel MKL with GNU OpenMP threading
3. Lapack + BLAS
4. BLAS alone
5. OpenBLAS
Returns
-------
blas flags: str
Blas flags needed to link to the BLAS implementation found in the system.
If no BLAS implementation is found, an empty string is returned.
Notes
-----
This function is triggered when `pytensor.config.blas__ldflags` is not given a user
default, and it is first accessed at runtime. It can be rather slow, so it is advised
to cache the results of this function in PYTENSORRC configuration file or
PyTensor environment flags.
"""
# If no compiler is available we default to empty ldflags # If no compiler is available we default to empty ldflags
if not config.cxx: if not config.cxx:
...@@ -2854,7 +2862,7 @@ def default_blas_ldflags() -> str: ...@@ -2854,7 +2862,7 @@ def default_blas_ldflags() -> str:
else: else:
rpath = None rpath = None
cxx_library_dirs = get_cxx_library_dirs() cxx_library_dirs = _get_cxx_library_dirs()
searched_library_dirs = cxx_library_dirs + _std_lib_dirs searched_library_dirs = cxx_library_dirs + _std_lib_dirs
if sys.platform == "win32": if sys.platform == "win32":
# Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin # Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
...@@ -2884,7 +2892,7 @@ def default_blas_ldflags() -> str: ...@@ -2884,7 +2892,7 @@ def default_blas_ldflags() -> str:
try: try:
# 1. Try to use MKL with INTEL OpenMP threading # 1. Try to use MKL with INTEL OpenMP threading
_logger.debug("Checking MKL flags with intel threading") _logger.debug("Checking MKL flags with intel threading")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=[ required_libs=[
"mkl_core", "mkl_core",
...@@ -2901,7 +2909,7 @@ def default_blas_ldflags() -> str: ...@@ -2901,7 +2909,7 @@ def default_blas_ldflags() -> str:
try: try:
# 2. Try to use MKL with GNU OpenMP threading # 2. Try to use MKL with GNU OpenMP threading
_logger.debug("Checking MKL flags with GNU OpenMP threading") _logger.debug("Checking MKL flags with GNU OpenMP threading")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"], required_libs=["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [], extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
...@@ -2924,7 +2932,7 @@ def default_blas_ldflags() -> str: ...@@ -2924,7 +2932,7 @@ def default_blas_ldflags() -> str:
try: try:
_logger.debug("Checking Lapack + blas") _logger.debug("Checking Lapack + blas")
# 4. Try to use LAPACK + BLAS # 4. Try to use LAPACK + BLAS
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["lapack", "blas", "cblas", "m"], required_libs=["lapack", "blas", "cblas", "m"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [], extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
...@@ -2935,7 +2943,7 @@ def default_blas_ldflags() -> str: ...@@ -2935,7 +2943,7 @@ def default_blas_ldflags() -> str:
try: try:
# 5. Try to use BLAS alone # 5. Try to use BLAS alone
_logger.debug("Checking blas alone") _logger.debug("Checking blas alone")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["blas", "cblas"], required_libs=["blas", "cblas"],
extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [], extra_compile_flags=[f"-Wl,-rpath,{rpath}"] if rpath is not None else [],
...@@ -2946,7 +2954,7 @@ def default_blas_ldflags() -> str: ...@@ -2946,7 +2954,7 @@ def default_blas_ldflags() -> str:
try: try:
# 6. Try to use openblas # 6. Try to use openblas
_logger.debug("Checking openblas") _logger.debug("Checking openblas")
return check_libs( return _check_libs(
all_libs, all_libs,
required_libs=["openblas", "gfortran", "gomp", "m"], required_libs=["openblas", "gfortran", "gomp", "m"],
extra_compile_flags=["-fopenmp", f"-Wl,-rpath,{rpath}"] extra_compile_flags=["-fopenmp", f"-Wl,-rpath,{rpath}"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论