提交 32af31fb authored 作者: lucianopaz's avatar lucianopaz 提交者: Ricardo Vieira

Include conda Library bin and lib paths to default blas ldflags search dirs

上级 91d3b7c0
...@@ -2744,7 +2744,9 @@ def default_blas_ldflags(): ...@@ -2744,7 +2744,9 @@ def default_blas_ldflags():
[pathlib.Path(p).resolve() for p in line[len("libraries: =") :].split(":")] [pathlib.Path(p).resolve() for p in line[len("libraries: =") :].split(":")]
for line in stdout.decode(sys.stdout.encoding).splitlines() for line in stdout.decode(sys.stdout.encoding).splitlines()
if line.startswith("libraries: =") if line.startswith("libraries: =")
][0] ]
if len(maybe_lib_dirs) > 0:
maybe_lib_dirs = maybe_lib_dirs[0]
return [str(d) for d in maybe_lib_dirs if d.exists() and d.is_dir()] return [str(d) for d in maybe_lib_dirs if d.exists() and d.is_dir()]
def check_libs( def check_libs(
...@@ -2793,6 +2795,13 @@ def default_blas_ldflags(): ...@@ -2793,6 +2795,13 @@ def default_blas_ldflags():
cxx_library_dirs = get_cxx_library_dirs() cxx_library_dirs = get_cxx_library_dirs()
searched_library_dirs = cxx_library_dirs + _std_lib_dirs searched_library_dirs = cxx_library_dirs + _std_lib_dirs
if sys.platform == "win32":
# Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
# From the conda manual (https://docs.conda.io/projects/conda-build/en/stable/user-guide/environment-variables.html)
# it seems like conda could also save some libraries into the CONDA_PREFIX\Library\lib
# directory. We will include both in our searched library dirs
searched_library_dirs.append(os.path.join(sys.prefix, "Library", "bin"))
searched_library_dirs.append(os.path.join(sys.prefix, "Library", "lib"))
all_libs = [ all_libs = [
l l
for path in [ for path in [
......
...@@ -260,6 +260,51 @@ def test_default_blas_ldflags_no_cxx(): ...@@ -260,6 +260,51 @@ def test_default_blas_ldflags_no_cxx():
assert default_blas_ldflags() == "" assert default_blas_ldflags() == ""
@pytest.fixture()
def windows_conda_libs(blas_libs):
libtemplate = "{lib}.dll"
libraries = []
with tempfile.TemporaryDirectory() as d:
subdir = os.path.join(d, "Library", "bin")
os.makedirs(subdir, exist_ok=True)
flags = f'-L"{subdir}"'
for lib in blas_libs:
lib_path = os.path.join(subdir, libtemplate.format(lib=lib))
with open(lib_path, "wb") as f:
f.write(b"1")
libraries.append(lib_path)
flags += f" -l{lib}"
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs:
flags += " -fopenmp"
if len(blas_libs) == 0:
flags = ""
yield d, flags
@patch("pytensor.link.c.cmodule.std_lib_dirs", return_value=[])
@patch("pytensor.link.c.cmodule.check_mkl_openmp", return_value=None)
def test_default_blas_ldflags_conda_windows(
mock_std_lib_dirs, mock_check_mkl_openmp, windows_conda_libs
):
mock_sys_prefix, expected_blas_ldflags = windows_conda_libs
mock_process = MagicMock()
mock_process.communicate = lambda *args, **kwargs: (b"", b"")
mock_process.returncode = 0
with patch("sys.platform", "win32"):
with patch("sys.prefix", mock_sys_prefix):
with patch(
"pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process
):
with patch.object(
pytensor.link.c.cmodule.GCC_compiler,
"try_compile_tmp",
return_value=(True, True),
):
assert set(default_blas_ldflags().split(" ")) == set(
expected_blas_ldflags.split(" ")
)
@patch( @patch(
"os.listdir", return_value=["mkl_core.1.dll", "mkl_rt.1.0.dll", "mkl_rt.1.1.lib"] "os.listdir", return_value=["mkl_core.1.dll", "mkl_rt.1.0.dll", "mkl_rt.1.1.lib"]
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论