提交 b4e95319 authored 作者: Luciano Paz's avatar Luciano Paz

Add Accelerate framework blas__ldflags tests

上级 e73258b4
......@@ -2458,7 +2458,23 @@ class GCC_compiler(Compiler):
@staticmethod
def linking_patch(lib_dirs: list[str], libs: list[str]) -> list[str]:
if sys.platform != "win32":
return [f"-l{l}" for l in libs]
patched_libs = []
framework = False
for lib in libs:
# The clang framework flag is handled differently.
# The flag will have the format -framework framework_name
# If we find a lib that is called -framework, we keep it and the following
# entry in the lib list unchanged. Anything else, we add the standard
# -l library prefix.
if lib == "-framework":
framework = True
patched_libs.append(lib)
elif framework:
framework = False
patched_libs.append(lib)
else:
patched_libs.append(f"-l{lib}")
return patched_libs
else:
# In explicit else because of https://github.com/python/mypy/issues/10773
def sort_key(lib):
......@@ -2466,6 +2482,8 @@ class GCC_compiler(Compiler):
return (extension == "dll", tuple(map(int, numbers)))
patched_lib_ldflags = []
# Should we also add a framework possibility on windows? I didn't do so because
# clang is not intended to be used there at the moment.
for lib in libs:
ldflag = f"-l{lib}"
for lib_dir in lib_dirs:
......@@ -2873,9 +2891,21 @@ def default_blas_ldflags():
)
except Exception as e:
_logger.debug(e)
try:
# 3. Mac Accelerate framework
_logger.debug("Checking Accelerate framework")
flags = ["-framework", "Accelerate"]
if rpath:
flags = [*flags, f"-Wl,-rpath,{rpath}"]
validated_flags = try_blas_flag(flags)
if validated_flags == "":
raise Exception("Accelerate framework flag failed ")
return validated_flags
except Exception as e:
_logger.debug(e)
try:
_logger.debug("Checking Lapack + blas")
# 3. Try to use LAPACK + BLAS
# 4. Try to use LAPACK + BLAS
return check_libs(
all_libs,
required_libs=["lapack", "blas", "cblas", "m"],
......@@ -2885,7 +2915,7 @@ def default_blas_ldflags():
except Exception as e:
_logger.debug(e)
try:
# 4. Try to use BLAS alone
# 5. Try to use BLAS alone
_logger.debug("Checking blas alone")
return check_libs(
all_libs,
......@@ -2896,7 +2926,7 @@ def default_blas_ldflags():
except Exception as e:
_logger.debug(e)
try:
# 5. Try to use openblas
# 6. Try to use openblas
_logger.debug("Checking openblas")
return check_libs(
all_libs,
......
......@@ -78,7 +78,9 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas
import functools
import logging
import os
import shlex
import time
from pathlib import Path
import numpy as np
......@@ -396,7 +398,7 @@ def _ldflags(
rval = []
if libs_dir:
found_dyn = False
dirs = [x[2:] for x in ldflags_str.split() if x.startswith("-L")]
dirs = [x[2:] for x in shlex.split(ldflags_str) if x.startswith("-L")]
l = _ldflags(
ldflags_str=ldflags_str,
libs=True,
......@@ -409,6 +411,9 @@ def _ldflags(
if f.endswith(".so") or f.endswith(".dylib") or f.endswith(".dll"):
if any(f.find(ll) >= 0 for ll in l):
found_dyn = True
# Special treatment of clang framework. Specifically for MacOS Accelerate
if "-framework" in l and "Accelerate" in l:
found_dyn = True
if not found_dyn and dirs:
_logger.warning(
"We did not find a dynamic library in the "
......@@ -416,7 +421,12 @@ def _ldflags(
"ATLAS, make sure to compile it with dynamics library."
)
for t in ldflags_str.split():
split_flags = shlex.split(ldflags_str)
skip = False
for pos, t in enumerate(split_flags):
if skip:
skip = False
continue
# Remove extra quote.
if (t.startswith("'") and t.endswith("'")) or (
t.startswith('"') and t.endswith('"')
......@@ -425,10 +435,26 @@ def _ldflags(
try:
t0, t1 = t[0], t[1]
assert t0 == "-"
assert t0 == "-" or Path(t).exists()
except Exception:
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
if libs_dir and t1 == "L":
if t == "-framework":
skip = True
# Special treatment of clang framework. Specifically for MacOS Accelerate
# The clang framework implicitly adds: header dirs, libraries, and library dirs.
# If we choose to always return these flags, we run into a huge deal amount of
# incompatibilities. For this reason, we only return the framework if libs are
# requested.
if (
libs
and len(split_flags) >= pos
and split_flags[pos + 1] == "Accelerate"
):
# We only add the Accelerate framework, but in the future we could extend it to
# other frameworks
rval.append(t)
rval.append(split_flags[pos + 1])
elif libs_dir and t1 == "L":
rval.append(t[2:])
elif include_dir and t1 == "I":
raise ValueError(
......
......@@ -165,13 +165,22 @@ def test_flag_detection():
@pytest.fixture(
scope="module",
params=["mkl_intel", "mkl_gnu", "openblas", "lapack", "blas", "no_blas"],
params=[
"mkl_intel",
"mkl_gnu",
"accelerate",
"openblas",
"lapack",
"blas",
"no_blas",
],
)
def blas_libs(request):
key = request.param
libs = {
"mkl_intel": ["mkl_core", "mkl_rt", "mkl_intel_thread", "iomp5", "pthread"],
"mkl_gnu": ["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
"accelerate": ["vecLib_placeholder"],
"openblas": ["openblas", "gfortran", "gomp", "m"],
"lapack": ["lapack", "blas", "cblas", "m"],
"blas": ["blas", "cblas"],
......@@ -190,9 +199,17 @@ def mock_system(request):
def cxx_search_dirs(blas_libs, mock_system):
libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}
libraries = []
enabled_accelerate_framework = False
with tempfile.TemporaryDirectory() as d:
flags = None
for lib in blas_libs:
if lib == "vecLib_placeholder":
if mock_system != "Darwin":
flags = ""
else:
flags = "-framework Accelerate"
enabled_accelerate_framework = True
else:
lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
lib_path.write_bytes(b"1")
libraries.append(lib_path)
......@@ -204,11 +221,15 @@ def cxx_search_dirs(blas_libs, mock_system):
flags += " -fopenmp"
if len(blas_libs) == 0:
flags = ""
yield f"libraries: ={d}".encode(sys.stdout.encoding), flags
yield (
f"libraries: ={d}".encode(sys.stdout.encoding),
flags,
enabled_accelerate_framework,
)
@pytest.fixture(
scope="function", params=[False, True], ids=["Working_CXX", "Broken_CXX"]
scope="function", params=[True, False], ids=["Working_CXX", "Broken_CXX"]
)
def cxx_search_dirs_status(request):
return request.param
......@@ -219,22 +240,39 @@ def cxx_search_dirs_status(request):
def test_default_blas_ldflags(
mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs, cxx_search_dirs_status
):
cxx_search_dirs, expected_blas_ldflags = cxx_search_dirs
cxx_search_dirs, expected_blas_ldflags, enabled_accelerate_framework = (
cxx_search_dirs
)
mock_process = MagicMock()
if cxx_search_dirs_status:
error_message = ""
mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, b"")
mock_process.returncode = 0
else:
enabled_accelerate_framework = False
error_message = "Unsupported argument -print-search-dirs"
error_message_bytes = error_message.encode(sys.stderr.encoding)
mock_process.communicate = lambda *args, **kwargs: (b"", error_message_bytes)
mock_process.returncode = 1
def patched_compile_tmp(*args, **kwargs):
def wrapped(test_code, tmp_prefix, flags, try_run, output):
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
print(enabled_accelerate_framework)
if enabled_accelerate_framework:
return (True, True)
else:
return (False, False, "", "Invalid flags -framework Accelerate")
else:
return (True, True)
return wrapped
with patch("pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process):
with patch.object(
pytensor.link.c.cmodule.GCC_compiler,
"try_compile_tmp",
return_value=(True, True),
new_callable=patched_compile_tmp,
):
if cxx_search_dirs_status:
assert set(default_blas_ldflags().split(" ")) == set(
......@@ -267,6 +305,9 @@ def windows_conda_libs(blas_libs):
subdir.mkdir(exist_ok=True, parents=True)
flags = f'-L"{subdir}"'
for lib in blas_libs:
if lib == "vecLib_placeholder":
flags = ""
break
lib_path = subdir / f"{lib}.dll"
lib_path.write_bytes(b"1")
libraries.append(lib_path)
......@@ -287,6 +328,16 @@ def test_default_blas_ldflags_conda_windows(
mock_process = MagicMock()
mock_process.communicate = lambda *args, **kwargs: (b"", b"")
mock_process.returncode = 0
def patched_compile_tmp(*args, **kwargs):
def wrapped(test_code, tmp_prefix, flags, try_run, output):
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
return (False, False, "", "Invalid flags -framework Accelerate")
else:
return (True, True)
return wrapped
with patch("sys.platform", "win32"):
with patch("sys.prefix", mock_sys_prefix):
with patch(
......@@ -295,7 +346,7 @@ def test_default_blas_ldflags_conda_windows(
with patch.object(
pytensor.link.c.cmodule.GCC_compiler,
"try_compile_tmp",
return_value=(True, True),
new_callable=patched_compile_tmp,
):
assert set(default_blas_ldflags().split(" ")) == set(
expected_blas_ldflags.split(" ")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论