提交 b4e95319 authored 作者: Luciano Paz's avatar Luciano Paz

Add Accelerate framework blas__ldflags tests

上级 e73258b4
...@@ -2458,7 +2458,23 @@ class GCC_compiler(Compiler): ...@@ -2458,7 +2458,23 @@ class GCC_compiler(Compiler):
@staticmethod @staticmethod
def linking_patch(lib_dirs: list[str], libs: list[str]) -> list[str]: def linking_patch(lib_dirs: list[str], libs: list[str]) -> list[str]:
if sys.platform != "win32": if sys.platform != "win32":
return [f"-l{l}" for l in libs] patched_libs = []
framework = False
for lib in libs:
# The clang framework flag is handled differently.
# The flag will have the format -framework framework_name
# If we find a lib that is called -framework, we keep it and the following
# entry in the lib list unchanged. Anything else, we add the standard
# -l library prefix.
if lib == "-framework":
framework = True
patched_libs.append(lib)
elif framework:
framework = False
patched_libs.append(lib)
else:
patched_libs.append(f"-l{lib}")
return patched_libs
else: else:
# In explicit else because of https://github.com/python/mypy/issues/10773 # In explicit else because of https://github.com/python/mypy/issues/10773
def sort_key(lib): def sort_key(lib):
...@@ -2466,6 +2482,8 @@ class GCC_compiler(Compiler): ...@@ -2466,6 +2482,8 @@ class GCC_compiler(Compiler):
return (extension == "dll", tuple(map(int, numbers))) return (extension == "dll", tuple(map(int, numbers)))
patched_lib_ldflags = [] patched_lib_ldflags = []
# Should we also add a framework possibility on windows? I didn't do so because
# clang is not intended to be used there at the moment.
for lib in libs: for lib in libs:
ldflag = f"-l{lib}" ldflag = f"-l{lib}"
for lib_dir in lib_dirs: for lib_dir in lib_dirs:
...@@ -2873,9 +2891,21 @@ def default_blas_ldflags(): ...@@ -2873,9 +2891,21 @@ def default_blas_ldflags():
) )
except Exception as e: except Exception as e:
_logger.debug(e) _logger.debug(e)
try:
# 3. Mac Accelerate framework
_logger.debug("Checking Accelerate framework")
flags = ["-framework", "Accelerate"]
if rpath:
flags = [*flags, f"-Wl,-rpath,{rpath}"]
validated_flags = try_blas_flag(flags)
if validated_flags == "":
raise Exception("Accelerate framework flag failed ")
return validated_flags
except Exception as e:
_logger.debug(e)
try: try:
_logger.debug("Checking Lapack + blas") _logger.debug("Checking Lapack + blas")
# 3. Try to use LAPACK + BLAS # 4. Try to use LAPACK + BLAS
return check_libs( return check_libs(
all_libs, all_libs,
required_libs=["lapack", "blas", "cblas", "m"], required_libs=["lapack", "blas", "cblas", "m"],
...@@ -2885,7 +2915,7 @@ def default_blas_ldflags(): ...@@ -2885,7 +2915,7 @@ def default_blas_ldflags():
except Exception as e: except Exception as e:
_logger.debug(e) _logger.debug(e)
try: try:
# 4. Try to use BLAS alone # 5. Try to use BLAS alone
_logger.debug("Checking blas alone") _logger.debug("Checking blas alone")
return check_libs( return check_libs(
all_libs, all_libs,
...@@ -2896,7 +2926,7 @@ def default_blas_ldflags(): ...@@ -2896,7 +2926,7 @@ def default_blas_ldflags():
except Exception as e: except Exception as e:
_logger.debug(e) _logger.debug(e)
try: try:
# 5. Try to use openblas # 6. Try to use openblas
_logger.debug("Checking openblas") _logger.debug("Checking openblas")
return check_libs( return check_libs(
all_libs, all_libs,
......
...@@ -78,7 +78,9 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas ...@@ -78,7 +78,9 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas
import functools import functools
import logging import logging
import os import os
import shlex
import time import time
from pathlib import Path
import numpy as np import numpy as np
...@@ -396,7 +398,7 @@ def _ldflags( ...@@ -396,7 +398,7 @@ def _ldflags(
rval = [] rval = []
if libs_dir: if libs_dir:
found_dyn = False found_dyn = False
dirs = [x[2:] for x in ldflags_str.split() if x.startswith("-L")] dirs = [x[2:] for x in shlex.split(ldflags_str) if x.startswith("-L")]
l = _ldflags( l = _ldflags(
ldflags_str=ldflags_str, ldflags_str=ldflags_str,
libs=True, libs=True,
...@@ -409,6 +411,9 @@ def _ldflags( ...@@ -409,6 +411,9 @@ def _ldflags(
if f.endswith(".so") or f.endswith(".dylib") or f.endswith(".dll"): if f.endswith(".so") or f.endswith(".dylib") or f.endswith(".dll"):
if any(f.find(ll) >= 0 for ll in l): if any(f.find(ll) >= 0 for ll in l):
found_dyn = True found_dyn = True
# Special treatment of clang framework. Specifically for MacOS Accelerate
if "-framework" in l and "Accelerate" in l:
found_dyn = True
if not found_dyn and dirs: if not found_dyn and dirs:
_logger.warning( _logger.warning(
"We did not find a dynamic library in the " "We did not find a dynamic library in the "
...@@ -416,7 +421,12 @@ def _ldflags( ...@@ -416,7 +421,12 @@ def _ldflags(
"ATLAS, make sure to compile it with dynamics library." "ATLAS, make sure to compile it with dynamics library."
) )
for t in ldflags_str.split(): split_flags = shlex.split(ldflags_str)
skip = False
for pos, t in enumerate(split_flags):
if skip:
skip = False
continue
# Remove extra quote. # Remove extra quote.
if (t.startswith("'") and t.endswith("'")) or ( if (t.startswith("'") and t.endswith("'")) or (
t.startswith('"') and t.endswith('"') t.startswith('"') and t.endswith('"')
...@@ -425,10 +435,26 @@ def _ldflags( ...@@ -425,10 +435,26 @@ def _ldflags(
try: try:
t0, t1 = t[0], t[1] t0, t1 = t[0], t[1]
assert t0 == "-" assert t0 == "-" or Path(t).exists()
except Exception: except Exception:
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"') raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
if libs_dir and t1 == "L": if t == "-framework":
skip = True
# Special treatment of clang framework. Specifically for MacOS Accelerate
# The clang framework implicitly adds: header dirs, libraries, and library dirs.
# If we choose to always return these flags, we run into a huge deal amount of
# incompatibilities. For this reason, we only return the framework if libs are
# requested.
if (
libs
and len(split_flags) >= pos
and split_flags[pos + 1] == "Accelerate"
):
# We only add the Accelerate framework, but in the future we could extend it to
# other frameworks
rval.append(t)
rval.append(split_flags[pos + 1])
elif libs_dir and t1 == "L":
rval.append(t[2:]) rval.append(t[2:])
elif include_dir and t1 == "I": elif include_dir and t1 == "I":
raise ValueError( raise ValueError(
......
...@@ -165,13 +165,22 @@ def test_flag_detection(): ...@@ -165,13 +165,22 @@ def test_flag_detection():
@pytest.fixture( @pytest.fixture(
scope="module", scope="module",
params=["mkl_intel", "mkl_gnu", "openblas", "lapack", "blas", "no_blas"], params=[
"mkl_intel",
"mkl_gnu",
"accelerate",
"openblas",
"lapack",
"blas",
"no_blas",
],
) )
def blas_libs(request): def blas_libs(request):
key = request.param key = request.param
libs = { libs = {
"mkl_intel": ["mkl_core", "mkl_rt", "mkl_intel_thread", "iomp5", "pthread"], "mkl_intel": ["mkl_core", "mkl_rt", "mkl_intel_thread", "iomp5", "pthread"],
"mkl_gnu": ["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"], "mkl_gnu": ["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
"accelerate": ["vecLib_placeholder"],
"openblas": ["openblas", "gfortran", "gomp", "m"], "openblas": ["openblas", "gfortran", "gomp", "m"],
"lapack": ["lapack", "blas", "cblas", "m"], "lapack": ["lapack", "blas", "cblas", "m"],
"blas": ["blas", "cblas"], "blas": ["blas", "cblas"],
...@@ -190,25 +199,37 @@ def mock_system(request): ...@@ -190,25 +199,37 @@ def mock_system(request):
def cxx_search_dirs(blas_libs, mock_system): def cxx_search_dirs(blas_libs, mock_system):
libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"} libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}
libraries = [] libraries = []
enabled_accelerate_framework = False
with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d:
flags = None flags = None
for lib in blas_libs: for lib in blas_libs:
lib_path = Path(d) / f"{lib}.{libext[mock_system]}" if lib == "vecLib_placeholder":
lib_path.write_bytes(b"1") if mock_system != "Darwin":
libraries.append(lib_path) flags = ""
if flags is None: else:
flags = f"-l{lib}" flags = "-framework Accelerate"
enabled_accelerate_framework = True
else: else:
flags += f" -l{lib}" lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
lib_path.write_bytes(b"1")
libraries.append(lib_path)
if flags is None:
flags = f"-l{lib}"
else:
flags += f" -l{lib}"
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs: if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs:
flags += " -fopenmp" flags += " -fopenmp"
if len(blas_libs) == 0: if len(blas_libs) == 0:
flags = "" flags = ""
yield f"libraries: ={d}".encode(sys.stdout.encoding), flags yield (
f"libraries: ={d}".encode(sys.stdout.encoding),
flags,
enabled_accelerate_framework,
)
@pytest.fixture( @pytest.fixture(
scope="function", params=[False, True], ids=["Working_CXX", "Broken_CXX"] scope="function", params=[True, False], ids=["Working_CXX", "Broken_CXX"]
) )
def cxx_search_dirs_status(request): def cxx_search_dirs_status(request):
return request.param return request.param
...@@ -219,22 +240,39 @@ def cxx_search_dirs_status(request): ...@@ -219,22 +240,39 @@ def cxx_search_dirs_status(request):
def test_default_blas_ldflags( def test_default_blas_ldflags(
mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs, cxx_search_dirs_status mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs, cxx_search_dirs_status
): ):
cxx_search_dirs, expected_blas_ldflags = cxx_search_dirs cxx_search_dirs, expected_blas_ldflags, enabled_accelerate_framework = (
cxx_search_dirs
)
mock_process = MagicMock() mock_process = MagicMock()
if cxx_search_dirs_status: if cxx_search_dirs_status:
error_message = "" error_message = ""
mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, b"") mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, b"")
mock_process.returncode = 0 mock_process.returncode = 0
else: else:
enabled_accelerate_framework = False
error_message = "Unsupported argument -print-search-dirs" error_message = "Unsupported argument -print-search-dirs"
error_message_bytes = error_message.encode(sys.stderr.encoding) error_message_bytes = error_message.encode(sys.stderr.encoding)
mock_process.communicate = lambda *args, **kwargs: (b"", error_message_bytes) mock_process.communicate = lambda *args, **kwargs: (b"", error_message_bytes)
mock_process.returncode = 1 mock_process.returncode = 1
def patched_compile_tmp(*args, **kwargs):
def wrapped(test_code, tmp_prefix, flags, try_run, output):
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
print(enabled_accelerate_framework)
if enabled_accelerate_framework:
return (True, True)
else:
return (False, False, "", "Invalid flags -framework Accelerate")
else:
return (True, True)
return wrapped
with patch("pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process): with patch("pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process):
with patch.object( with patch.object(
pytensor.link.c.cmodule.GCC_compiler, pytensor.link.c.cmodule.GCC_compiler,
"try_compile_tmp", "try_compile_tmp",
return_value=(True, True), new_callable=patched_compile_tmp,
): ):
if cxx_search_dirs_status: if cxx_search_dirs_status:
assert set(default_blas_ldflags().split(" ")) == set( assert set(default_blas_ldflags().split(" ")) == set(
...@@ -267,6 +305,9 @@ def windows_conda_libs(blas_libs): ...@@ -267,6 +305,9 @@ def windows_conda_libs(blas_libs):
subdir.mkdir(exist_ok=True, parents=True) subdir.mkdir(exist_ok=True, parents=True)
flags = f'-L"{subdir}"' flags = f'-L"{subdir}"'
for lib in blas_libs: for lib in blas_libs:
if lib == "vecLib_placeholder":
flags = ""
break
lib_path = subdir / f"{lib}.dll" lib_path = subdir / f"{lib}.dll"
lib_path.write_bytes(b"1") lib_path.write_bytes(b"1")
libraries.append(lib_path) libraries.append(lib_path)
...@@ -287,6 +328,16 @@ def test_default_blas_ldflags_conda_windows( ...@@ -287,6 +328,16 @@ def test_default_blas_ldflags_conda_windows(
mock_process = MagicMock() mock_process = MagicMock()
mock_process.communicate = lambda *args, **kwargs: (b"", b"") mock_process.communicate = lambda *args, **kwargs: (b"", b"")
mock_process.returncode = 0 mock_process.returncode = 0
def patched_compile_tmp(*args, **kwargs):
def wrapped(test_code, tmp_prefix, flags, try_run, output):
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
return (False, False, "", "Invalid flags -framework Accelerate")
else:
return (True, True)
return wrapped
with patch("sys.platform", "win32"): with patch("sys.platform", "win32"):
with patch("sys.prefix", mock_sys_prefix): with patch("sys.prefix", mock_sys_prefix):
with patch( with patch(
...@@ -295,7 +346,7 @@ def test_default_blas_ldflags_conda_windows( ...@@ -295,7 +346,7 @@ def test_default_blas_ldflags_conda_windows(
with patch.object( with patch.object(
pytensor.link.c.cmodule.GCC_compiler, pytensor.link.c.cmodule.GCC_compiler,
"try_compile_tmp", "try_compile_tmp",
return_value=(True, True), new_callable=patched_compile_tmp,
): ):
assert set(default_blas_ldflags().split(" ")) == set( assert set(default_blas_ldflags().split(" ")) == set(
expected_blas_ldflags.split(" ") expected_blas_ldflags.split(" ")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论