提交 88e72e8d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use the new dnn variables appropriately.

上级 89bed35d
...@@ -72,16 +72,16 @@ def _dnn_lib(): ...@@ -72,16 +72,16 @@ def _dnn_lib():
if _dnn_lib.handle is None: if _dnn_lib.handle is None:
import ctypes.util import ctypes.util
if config.dnn.library_path != "": if config.dnn.bin_path != "":
if sys.platform == 'darwin': if sys.platform == 'darwin':
dnn_handle = _load_lib(os.path.join(config.dnn.library_path, 'libcudnn.dylib')) dnn_handle = _load_lib(os.path.join(config.dnn.bin_path, 'libcudnn.dylib'))
elif sys.platform == 'win32': elif sys.platform == 'win32':
for name in WIN32_CUDNN_NAMES: for name in WIN32_CUDNN_NAMES:
dnn_handle = _load_lib(os.path.join(config.dnn.library_path, name)) dnn_handle = _load_lib(os.path.join(config.dnn.bin_path, name))
if dnn_handle is not None: if dnn_handle is not None:
break break
else: else:
dnn_handle = _load_lib(os.path.join(config.dnn.library_path, 'libcudnn.so')) dnn_handle = _load_lib(os.path.join(config.dnn.bin_path, 'libcudnn.so'))
else: else:
lib_name = ctypes.util.find_library('cudnn') lib_name = ctypes.util.find_library('cudnn')
if lib_name is None and sys.platform == 'win32': if lib_name is None and sys.platform == 'win32':
...@@ -135,7 +135,6 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) { ...@@ -135,7 +135,6 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
} }
""" """
params = ["-l", "cudnn", "-I" + os.path.dirname(__file__)]
path_wrapper = "\"" if os.name == 'nt' else "" path_wrapper = "\"" if os.name == 'nt' else ""
params = ["-l", "cudnn"] params = ["-l", "cudnn"]
params.extend(['-I%s%s%s' % (path_wrapper, os.path.dirname(__file__), path_wrapper)]) params.extend(['-I%s%s%s' % (path_wrapper, os.path.dirname(__file__), path_wrapper)])
...@@ -230,7 +229,12 @@ class DnnVersion(Op): ...@@ -230,7 +229,12 @@ class DnnVersion(Op):
return ['cudnn.h'] return ['cudnn.h']
def c_header_dirs(self): def c_header_dirs(self):
return [config.dnn.include_path] if config.dnn.include_path else [] res = []
if config.dnn.include_path:
res.append(config.dnn.include_path)
if config.dnn.cuda_include_path:
res.append(config.dnn.cuda_include_path)
return res
def c_libraries(self): def c_libraries(self):
return ['cudnn'] return ['cudnn']
...@@ -241,11 +245,11 @@ class DnnVersion(Op): ...@@ -241,11 +245,11 @@ class DnnVersion(Op):
return [] return []
def c_compile_args(self): def c_compile_args(self):
if config.dnn.library_path: if config.dnn.bin_path:
if sys.platform == 'darwin': if sys.platform == 'darwin':
return ['-Wl,-rpath,' + config.dnn.library_path] return ['-Wl,-rpath,' + config.dnn.bin_path]
else: else:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"'] return ['-Wl,-rpath,"' + config.dnn.bin_path + '"']
return [] return []
def c_support_code(self): def c_support_code(self):
...@@ -371,6 +375,8 @@ class DnnBase(COp): ...@@ -371,6 +375,8 @@ class DnnBase(COp):
dirs = [os.path.dirname(__file__), pygpu.get_include()] dirs = [os.path.dirname(__file__), pygpu.get_include()]
if config.dnn.include_path: if config.dnn.include_path:
dirs.append(config.dnn.include_path) dirs.append(config.dnn.include_path)
if config.dnn.cuda_include_path:
dirs.append(config.dnn.cuda_include_path)
return dirs return dirs
def c_libraries(self): def c_libraries(self):
...@@ -382,11 +388,11 @@ class DnnBase(COp): ...@@ -382,11 +388,11 @@ class DnnBase(COp):
return [] return []
def c_compile_args(self): def c_compile_args(self):
if config.dnn.library_path: if config.dnn.bin_path:
if sys.platform == 'darwin': if sys.platform == 'darwin':
return ['-Wl,-rpath,' + config.dnn.library_path] return ['-Wl,-rpath,' + config.dnn.bin_path]
else: else:
return ['-Wl,-rpath,"' + config.dnn.library_path + '"'] return ['-Wl,-rpath,"' + config.dnn.bin_path + '"']
return [] return []
def c_code_cache_version(self): def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论