提交 1901d980 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make sure MKL uses GNU OpenMP.

上级 b6e37683
...@@ -1238,6 +1238,29 @@ AddConfigVar('cmodule.debug', ...@@ -1238,6 +1238,29 @@ AddConfigVar('cmodule.debug',
in_c_key=True) in_c_key=True)
def check_mkl_openmp():
if not theano.config.blas.check_openmp:
return
import os
if ('MKL_THREADING_LAYER' in os.environ and
os.environ['MKL_THREADING_LAYER'] == 'GNU'):
return
try:
import mkl
if '2018' in mkl.get_version_string():
raise RuntimeError('To use MKL 2018 with Theano you MUST set "MKL_THREADING_LAYER=GNU" in your environement.')
except ImportError:
raise RuntimeError("""
Could not import 'mkl'. Either install mkl-service with conda or set
MKL_THREADING_LAYER=GNU in your environment for MKL 2018.
If you have MKL 2017 install and are not in a conda environment you
can set the Theano flag blas.check_openmp to False. Be warned that if
you set this flag and don't set the appropriate environment or make
sure you have the right version you *will* get wrong results.
""")
def default_blas_ldflags(): def default_blas_ldflags():
global numpy global numpy
warn_record = [] warn_record = []
...@@ -1366,16 +1389,22 @@ def default_blas_ldflags(): ...@@ -1366,16 +1389,22 @@ def default_blas_ldflags():
flags = [] flags = []
if lib_path: if lib_path:
flags = ['-L%s' % lib_path[0]] flags = ['-L%s' % lib_path[0]]
if '2018' in mkl.get_version_string():
thr = 'mkl_gnu_thread'
else:
thr = 'mkl_intel_thread'
flags += ['-l%s' % l for l in ["mkl_core", flags += ['-l%s' % l for l in ["mkl_core",
"mkl_intel_thread", thr,
"mkl_rt"]] "mkl_rt"]]
res = try_blas_flag(flags) res = try_blas_flag(flags)
if res: if res:
check_mkl_openmp()
return res return res
flags.extend(['-Wl,-rpath,' + l for l in flags.extend(['-Wl,-rpath,' + l for l in
blas_info.get('library_dirs', [])]) blas_info.get('library_dirs', [])])
res = try_blas_flag(flags) res = try_blas_flag(flags)
if res: if res:
check_mkl_openmp()
maybe_add_to_os_environ_pathlist('PATH', lib_path[0]) maybe_add_to_os_environ_pathlist('PATH', lib_path[0])
return res return res
...@@ -1396,6 +1425,8 @@ def default_blas_ldflags(): ...@@ -1396,6 +1425,8 @@ def default_blas_ldflags():
ret.extend(['-lm', '-lm']) ret.extend(['-lm', '-lm'])
res = try_blas_flag(ret) res = try_blas_flag(ret)
if res: if res:
if 'mkl' in res:
check_mkl_openmp()
return res return res
# If we are using conda and can't reuse numpy blas, then doing # If we are using conda and can't reuse numpy blas, then doing
...@@ -1411,6 +1442,8 @@ def default_blas_ldflags(): ...@@ -1411,6 +1442,8 @@ def default_blas_ldflags():
blas_info.get('library_dirs', [])]) blas_info.get('library_dirs', [])])
res = try_blas_flag(ret) res = try_blas_flag(ret)
if res: if res:
if 'mkl' in res:
check_mkl_openmp()
return res return res
# Add sys.prefix/lib to the runtime search path. On # Add sys.prefix/lib to the runtime search path. On
...@@ -1421,6 +1454,8 @@ def default_blas_ldflags(): ...@@ -1421,6 +1454,8 @@ def default_blas_ldflags():
ret.append('-Wl,-rpath,' + lib_path) ret.append('-Wl,-rpath,' + lib_path)
res = try_blas_flag(ret) res = try_blas_flag(ret)
if res: if res:
if 'mkl' in res:
check_mkl_openmp()
return res return res
except KeyError: except KeyError:
...@@ -1473,6 +1508,11 @@ AddConfigVar('blas.ldflags', ...@@ -1473,6 +1508,11 @@ AddConfigVar('blas.ldflags',
# Added elsewhere in the c key only when needed. # Added elsewhere in the c key only when needed.
in_c_key=False) in_c_key=False)
AddConfigVar('blas.check_openmp',
"Check for openmp library conflict.\nWARNING: Setting this to False leaves you open to wrong results in blas-related operations.",
BoolParam(True),
in_c_key=False)
AddConfigVar( AddConfigVar(
'metaopt.verbose', 'metaopt.verbose',
"0 for silent, 1 for only warnings, 2 for full output with" "0 for silent, 1 for only warnings, 2 for full output with"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论