提交 c649d044 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6147 from abergeron/fix_dnn_mac

Fix dnn loading/usage on mac
...@@ -345,25 +345,16 @@ def default_dnn_include_path(): ...@@ -345,25 +345,16 @@ def default_dnn_include_path():
return os.path.join(cuda_root, 'include') return os.path.join(cuda_root, 'include')
def default_dnn_library_path():
cuda_root = get_cuda_root()
if cuda_root == '':
return ''
if sys.platform == 'darwin':
return os.path.join(cuda_root, 'lib')
if sys.platform == 'win32':
return os.path.join(cuda_root, 'lib', 'x64')
return os.path.join(cuda_root, 'lib64')
AddConfigVar('dnn.include_path', AddConfigVar('dnn.include_path',
"Location of the cudnn header (defaults to the cuda root)", "Location of the cudnn header (defaults to the cuda root)",
# We keep the default here since cudnn needs headers from cuda
StrParam(default_dnn_include_path()), StrParam(default_dnn_include_path()),
# Added elsewhere in the c key only when needed. # Added elsewhere in the c key only when needed.
in_c_key=False) in_c_key=False)
AddConfigVar('dnn.library_path', AddConfigVar('dnn.library_path',
"Location of the cudnn library (defaults to the cuda root)", "Location of the cudnn library (defaults to the cuda root)",
StrParam(default_dnn_library_path()), StrParam(''),
# Added elsewhere in the c key only when needed. # Added elsewhere in the c key only when needed.
in_c_key=False) in_c_key=False)
......
...@@ -57,21 +57,46 @@ try: ...@@ -57,21 +57,46 @@ try:
except ImportError: except ImportError:
pass pass
# Update these names when new versions of cudnn are supported.
WIN32_CUDNN_NAMES = ['cudnn64_6.dll', 'cudnn64_5.dll']
def _load_lib(name):
try:
return ctypes.cdll.LoadLibrary(name)
except OSError:
return None
def _dnn_lib(): def _dnn_lib():
if _dnn_lib.handle is None: if _dnn_lib.handle is None:
import ctypes.util import ctypes.util
if config.dnn.library_path != "":
if sys.platform == 'darwin':
dnn_handle = _load_lib(os.path.join(config.dnn.library_path, 'libcudnn.dylib'))
elif sys.platform == 'win32':
for name in WIN32_CUDNN_NAMES:
dnn_handle = _load_lib(os.path.join(config.dnn.library_path, name))
if dnn_handle is not None:
break
else:
dnn_handle = _load_lib(os.path.join(config.dnn.library_path, 'libcudnn.so'))
else:
lib_name = ctypes.util.find_library('cudnn') lib_name = ctypes.util.find_library('cudnn')
if lib_name is None and sys.platform == 'win32': if lib_name is None and sys.platform == 'win32':
# Update these names when new versions of cudnn are supported. for name in WIN32_CUDNN_NAMES:
for name in ['cudnn64_6.dll', 'cudnn64_5.dll']:
lib_name = ctypes.util.find_library(name) lib_name = ctypes.util.find_library(name)
if lib_name: if lib_name:
break break
if lib_name is None: if lib_name is None:
raise RuntimeError('Could not find cudnn library (looked for v5* or v6*)') raise RuntimeError('Could not find cudnn library (looked for v5* or v6*)')
_dnn_lib.handle = ctypes.cdll.LoadLibrary(lib_name) else:
dnn_handle = ctypes.cdll.LoadLibrary(lib_name)
if dnn_handle is None:
raise RuntimeError('Could not load cudnn library')
_dnn_lib.handle = dnn_handle
cudnn = _dnn_lib.handle cudnn = _dnn_lib.handle
cudnn.cudnnCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)] cudnn.cudnnCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
cudnn.cudnnCreate.restype = ctypes.c_int cudnn.cudnnCreate.restype = ctypes.c_int
...@@ -217,6 +242,9 @@ class DnnVersion(Op): ...@@ -217,6 +242,9 @@ class DnnVersion(Op):
def c_compile_args(self): def c_compile_args(self):
if config.dnn.library_path: if config.dnn.library_path:
if sys.platform == 'darwin':
return ['-Wl,-rpath,' + config.dnn.library_path]
else:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"'] return ['-Wl,-rpath,"' + config.dnn.library_path + '"']
return [] return []
...@@ -355,6 +383,9 @@ class DnnBase(COp): ...@@ -355,6 +383,9 @@ class DnnBase(COp):
def c_compile_args(self): def c_compile_args(self):
if config.dnn.library_path: if config.dnn.library_path:
if sys.platform == 'darwin':
return ['-Wl,-rpath,' + config.dnn.library_path]
else:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"'] return ['-Wl,-rpath,"' + config.dnn.library_path + '"']
return [] return []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论