提交 48f1deb9 authored 作者: lucianopaz's avatar lucianopaz 提交者: Thomas Wiecki

Make default_blas_ldflags to not rely on numpy blas_info

上级 5ad11819
差异被折叠。
...@@ -4,11 +4,11 @@ We don't have real tests for the cache, but it would be great to make them! ...@@ -4,11 +4,11 @@ We don't have real tests for the cache, but it would be great to make them!
But this one tests a current behavior that isn't good: the c_code isn't But this one tests a current behavior that isn't good: the c_code isn't
deterministic based on the input type and the op. deterministic based on the input type and the op.
""" """
import logging
import multiprocessing import multiprocessing
import os import os
import sys
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
...@@ -161,16 +161,69 @@ def test_flag_detection(): ...@@ -161,16 +161,69 @@ def test_flag_detection():
assert isinstance(res, bool) assert isinstance(res, bool)
@patch("pytensor.link.c.cmodule.try_blas_flag", return_value=None) @pytest.fixture(
@patch("pytensor.link.c.cmodule.sys") scope="module",
def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog): params=["mkl_intel", "mkl_gnu", "openblas", "lapack", "blas", "no_blas"],
sys_mock.version = "3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:38) \n[GCC 7.3.0]" )
def blas_libs(request):
with patch.dict("sys.modules", {"mkl": None}): key = request.param
with caplog.at_level(logging.WARNING): libs = {
default_blas_ldflags() "mkl_intel": ["mkl_core", "mkl_rt", "mkl_intel_thread", "iomp5", "pthread"],
"mkl_gnu": ["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
assert caplog.text == "" "openblas": ["openblas", "gfortran", "gomp", "m"],
"lapack": ["lapack", "blas", "cblas", "m"],
"blas": ["blas", "cblas"],
"no_blas": [],
}
return libs[key]
@pytest.fixture(scope="function", params=["Linux", "Windows", "Darwin"])
def mock_system(request):
with patch("platform.system", return_value=request.param):
yield request.param
@pytest.fixture()
def cxx_search_dirs(blas_libs, mock_system):
libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}
libtemplate = f"{{lib}}.{libext[mock_system]}"
libraries = []
with tempfile.TemporaryDirectory() as d:
flags = None
for lib in blas_libs:
lib_path = os.path.join(d, libtemplate.format(lib=lib))
with open(lib_path, "wb") as f:
f.write(b"1")
libraries.append(lib_path)
if flags is None:
flags = f"-l{lib}"
else:
flags += f" -l{lib}"
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs:
flags += " -fopenmp"
if len(blas_libs) == 0:
flags = ""
yield f"libraries: ={d}".encode(sys.stdout.encoding), flags
@patch("pytensor.link.c.cmodule.std_lib_dirs", return_value=[])
@patch("pytensor.link.c.cmodule.check_mkl_openmp", return_value=None)
def test_default_blas_ldflags(
mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs
):
cxx_search_dirs, expected_blas_ldflags = cxx_search_dirs
mock_process = MagicMock()
mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, None)
with patch("pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process):
with patch.object(
pytensor.link.c.cmodule.GCC_compiler,
"try_compile_tmp",
return_value=(True, True),
):
assert set(default_blas_ldflags().split(" ")) == set(
expected_blas_ldflags.split(" ")
)
@patch( @patch(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论