提交 c649d044 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6147 from abergeron/fix_dnn_mac

Fix dnn loading/usage on mac
...@@ -345,25 +345,16 @@ def default_dnn_include_path(): ...@@ -345,25 +345,16 @@ def default_dnn_include_path():
return os.path.join(cuda_root, 'include') return os.path.join(cuda_root, 'include')
def default_dnn_library_path():
cuda_root = get_cuda_root()
if cuda_root == '':
return ''
if sys.platform == 'darwin':
return os.path.join(cuda_root, 'lib')
if sys.platform == 'win32':
return os.path.join(cuda_root, 'lib', 'x64')
return os.path.join(cuda_root, 'lib64')
AddConfigVar('dnn.include_path', AddConfigVar('dnn.include_path',
"Location of the cudnn header (defaults to the cuda root)", "Location of the cudnn header (defaults to the cuda root)",
# We keep the default here since cudnn needs headers from cuda
StrParam(default_dnn_include_path()), StrParam(default_dnn_include_path()),
# Added elsewhere in the c key only when needed. # Added elsewhere in the c key only when needed.
in_c_key=False) in_c_key=False)
AddConfigVar('dnn.library_path', AddConfigVar('dnn.library_path',
"Location of the cudnn library (defaults to the cuda root)", "Location of the cudnn library (defaults to the cuda root)",
StrParam(default_dnn_library_path()), StrParam(''),
# Added elsewhere in the c key only when needed. # Added elsewhere in the c key only when needed.
in_c_key=False) in_c_key=False)
......
...@@ -57,21 +57,46 @@ try: ...@@ -57,21 +57,46 @@ try:
except ImportError: except ImportError:
pass pass
# Update these names when new versions of cudnn are supported.
WIN32_CUDNN_NAMES = ['cudnn64_6.dll', 'cudnn64_5.dll']
def _load_lib(name):
try:
return ctypes.cdll.LoadLibrary(name)
except OSError:
return None
def _dnn_lib(): def _dnn_lib():
if _dnn_lib.handle is None: if _dnn_lib.handle is None:
import ctypes.util import ctypes.util
lib_name = ctypes.util.find_library('cudnn') if config.dnn.library_path != "":
if lib_name is None and sys.platform == 'win32': if sys.platform == 'darwin':
# Update these names when new versions of cudnn are supported. dnn_handle = _load_lib(os.path.join(config.dnn.library_path, 'libcudnn.dylib'))
for name in ['cudnn64_6.dll', 'cudnn64_5.dll']: elif sys.platform == 'win32':
lib_name = ctypes.util.find_library(name) for name in WIN32_CUDNN_NAMES:
if lib_name: dnn_handle = _load_lib(os.path.join(config.dnn.library_path, name))
break if dnn_handle is not None:
if lib_name is None: break
raise RuntimeError('Could not find cudnn library (looked for v5* or v6*)') else:
_dnn_lib.handle = ctypes.cdll.LoadLibrary(lib_name) dnn_handle = _load_lib(os.path.join(config.dnn.library_path, 'libcudnn.so'))
else:
lib_name = ctypes.util.find_library('cudnn')
if lib_name is None and sys.platform == 'win32':
for name in WIN32_CUDNN_NAMES:
lib_name = ctypes.util.find_library(name)
if lib_name:
break
if lib_name is None:
raise RuntimeError('Could not find cudnn library (looked for v5* or v6*)')
else:
dnn_handle = ctypes.cdll.LoadLibrary(lib_name)
if dnn_handle is None:
raise RuntimeError('Could not load cudnn library')
_dnn_lib.handle = dnn_handle
cudnn = _dnn_lib.handle cudnn = _dnn_lib.handle
cudnn.cudnnCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)] cudnn.cudnnCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
cudnn.cudnnCreate.restype = ctypes.c_int cudnn.cudnnCreate.restype = ctypes.c_int
...@@ -217,7 +242,10 @@ class DnnVersion(Op): ...@@ -217,7 +242,10 @@ class DnnVersion(Op):
def c_compile_args(self): def c_compile_args(self):
if config.dnn.library_path: if config.dnn.library_path:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"'] if sys.platform == 'darwin':
return ['-Wl,-rpath,' + config.dnn.library_path]
else:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"']
return [] return []
def c_support_code(self): def c_support_code(self):
...@@ -355,7 +383,10 @@ class DnnBase(COp): ...@@ -355,7 +383,10 @@ class DnnBase(COp):
def c_compile_args(self): def c_compile_args(self):
if config.dnn.library_path: if config.dnn.library_path:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"'] if sys.platform == 'darwin':
return ['-Wl,-rpath,' + config.dnn.library_path]
else:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"']
return [] return []
def c_code_cache_version(self): def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论