提交 b4ff2ff2 authored 作者: Matt Graham's avatar Matt Graham

Addition of user specified cuDNN paths to Op compile commands.

上级 6b02f8ca
...@@ -355,9 +355,15 @@ class DnnVersion(GpuOp): ...@@ -355,9 +355,15 @@ class DnnVersion(GpuOp):
def c_headers(self): def c_headers(self):
return ['cudnn.h'] return ['cudnn.h']
def c_header_dirs(self):
return [config.dnn.include_path]
def c_libraries(self): def c_libraries(self):
return ['cudnn'] return ['cudnn']
def c_lib_dirs(self):
return [config.dnn.library_path]
def c_support_code(self): def c_support_code(self):
return """ return """
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
......
...@@ -86,11 +86,14 @@ class DnnBase(GpuOp, COp): ...@@ -86,11 +86,14 @@ class DnnBase(GpuOp, COp):
return ['cudnn.h', 'cudnn_helper.h'] return ['cudnn.h', 'cudnn_helper.h']
def c_header_dirs(self): def c_header_dirs(self):
return [os.path.dirname(__file__)] return [os.path.dirname(__file__), config.dnn.include_path]
def c_libraries(self): def c_libraries(self):
return ['cudnn'] return ['cudnn']
def c_lib_dirs(self):
return [config.dnn.library_path]
class GpuDnnConvDesc(GpuOp): class GpuDnnConvDesc(GpuOp):
""" """
...@@ -107,11 +110,14 @@ class GpuDnnConvDesc(GpuOp): ...@@ -107,11 +110,14 @@ class GpuDnnConvDesc(GpuOp):
return ['cudnn.h', 'cudnn_helper.h'] return ['cudnn.h', 'cudnn_helper.h']
def c_header_dirs(self): def c_header_dirs(self):
return [os.path.dirname(__file__)] return [os.path.dirname(__file__), config.dnn.include_path]
def c_libraries(self): def c_libraries(self):
return ['cudnn'] return ['cudnn']
def c_lib_dirs(self):
return [config.dnn.library_path]
def c_compiler(self): def c_compiler(self):
return NVCC_compiler return NVCC_compiler
...@@ -1256,11 +1262,14 @@ class GpuDnnPoolDesc(GpuOp): ...@@ -1256,11 +1262,14 @@ class GpuDnnPoolDesc(GpuOp):
return ['cudnn.h', 'cudnn_helper.h'] return ['cudnn.h', 'cudnn_helper.h']
def c_header_dirs(self): def c_header_dirs(self):
return [os.path.dirname(__file__)] return [os.path.dirname(__file__), config.dnn.include_path]
def c_libraries(self): def c_libraries(self):
return ['cudnn'] return ['cudnn']
def c_lib_dirs(self):
return [config.dnn.library_path]
def c_compiler(self): def c_compiler(self):
return NVCC_compiler return NVCC_compiler
......
...@@ -218,7 +218,7 @@ class NVCC_compiler(Compiler): ...@@ -218,7 +218,7 @@ class NVCC_compiler(Compiler):
if 'cudart' not in libs: if 'cudart' not in libs:
libs.append('cudart') libs.append('cudart')
lib_dirs = std_lib_dirs() + lib_dirs lib_dirs = lib_dirs + std_lib_dirs()
if any(ld == os.path.join(cuda_root, 'lib') or if any(ld == os.path.join(cuda_root, 'lib') or
ld == os.path.join(cuda_root, 'lib64') for ld in lib_dirs): ld == os.path.join(cuda_root, 'lib64') for ld in lib_dirs):
warnings.warn("You have the cuda library directory in your " warnings.warn("You have the cuda library directory in your "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论