提交 89bed35d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add config variable for the cuda include path that is required.

上级 6674b0eb
......@@ -325,13 +325,15 @@ AddConfigVar('dnn.conv.precision',
in_c_key=False)
def get_dnn_path():
def get_cuda_root():
# We look for the cuda path since we need headers from there
v = os.getenv('CUDA_ROOT', "")
if v:
return v
s = os.getenv("PATH")
if not s:
if os.path.exists('/usr/local/cuda'):
return '/usr/local/cuda'
return ''
for dir in s.split(os.path.pathsep):
if os.path.exists(os.path.join(dir, "nvcc")):
......@@ -341,7 +343,7 @@ def get_dnn_path():
AddConfigVar('dnn.base_path',
"Install location of cuDNN.",
StrParam(get_cuda_root),
StrParam(''),
in_c_key=False)
......@@ -357,6 +359,19 @@ AddConfigVar('dnn.include_path',
in_c_key=False)
def default_dnn_cuda_inc_path():
cuda_root = get_cuda_root()
if cuda_root:
return os.path.join(cuda_root, 'include')
return ''
AddConfigVar('dnn.cuda_include_path',
"Location of the cuda include for cudnn",
StrParam(default_dnn_cuda_inc_path),
in_c_key=False)
def default_dnn_lib_path():
if theano.config.dnn.base_path != '':
path = os.path.join(theano.config.dnn.base_path, 'lib')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论