提交 1901d980 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make sure MKL uses GNU OpenMP.

上级 b6e37683
......@@ -1238,6 +1238,29 @@ AddConfigVar('cmodule.debug',
in_c_key=True)
def check_mkl_openmp():
if not theano.config.blas.check_openmp:
return
import os
if ('MKL_THREADING_LAYER' in os.environ and
os.environ['MKL_THREADING_LAYER'] == 'GNU'):
return
try:
import mkl
if '2018' in mkl.get_version_string():
raise RuntimeError('To use MKL 2018 with Theano you MUST set "MKL_THREADING_LAYER=GNU" in your environement.')
except ImportError:
raise RuntimeError("""
Could not import 'mkl'. Either install mkl-service with conda or set
MKL_THREADING_LAYER=GNU in your environment for MKL 2018.
If you have MKL 2017 install and are not in a conda environment you
can set the Theano flag blas.check_openmp to False. Be warned that if
you set this flag and don't set the appropriate environment or make
sure you have the right version you *will* get wrong results.
""")
def default_blas_ldflags():
global numpy
warn_record = []
......@@ -1366,16 +1389,22 @@ def default_blas_ldflags():
flags = []
if lib_path:
flags = ['-L%s' % lib_path[0]]
if '2018' in mkl.get_version_string():
thr = 'mkl_gnu_thread'
else:
thr = 'mkl_intel_thread'
flags += ['-l%s' % l for l in ["mkl_core",
"mkl_intel_thread",
thr,
"mkl_rt"]]
res = try_blas_flag(flags)
if res:
check_mkl_openmp()
return res
flags.extend(['-Wl,-rpath,' + l for l in
blas_info.get('library_dirs', [])])
res = try_blas_flag(flags)
if res:
check_mkl_openmp()
maybe_add_to_os_environ_pathlist('PATH', lib_path[0])
return res
......@@ -1396,6 +1425,8 @@ def default_blas_ldflags():
ret.extend(['-lm', '-lm'])
res = try_blas_flag(ret)
if res:
if 'mkl' in res:
check_mkl_openmp()
return res
# If we are using conda and can't reuse numpy blas, then doing
......@@ -1411,6 +1442,8 @@ def default_blas_ldflags():
blas_info.get('library_dirs', [])])
res = try_blas_flag(ret)
if res:
if 'mkl' in res:
check_mkl_openmp()
return res
# Add sys.prefix/lib to the runtime search path. On
......@@ -1421,6 +1454,8 @@ def default_blas_ldflags():
ret.append('-Wl,-rpath,' + lib_path)
res = try_blas_flag(ret)
if res:
if 'mkl' in res:
check_mkl_openmp()
return res
except KeyError:
......@@ -1473,6 +1508,11 @@ AddConfigVar('blas.ldflags',
# Added elsewhere in the c key only when needed.
in_c_key=False)
AddConfigVar('blas.check_openmp',
"Check for openmp library conflict.\nWARNING: Setting this to False leaves you open to wrong results in blas-related operations.",
BoolParam(True),
in_c_key=False)
AddConfigVar(
'metaopt.verbose',
"0 for silent, 1 for only warnings, 2 for full output with"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论