提交 6674b0eb authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Rework the config variables for cudnn to take into account the distinction…

Rework the config variables for cudnn to take into account the distinction between load and link libraries on windows.
上级 f9dc9ff6
...@@ -325,7 +325,8 @@ AddConfigVar('dnn.conv.precision', ...@@ -325,7 +325,8 @@ AddConfigVar('dnn.conv.precision',
in_c_key=False) in_c_key=False)
def get_cuda_root(): def get_dnn_path():
# We look for the cuda path since we need headers from there
v = os.getenv('CUDA_ROOT', "") v = os.getenv('CUDA_ROOT', "")
if v: if v:
return v return v
...@@ -338,26 +339,59 @@ def get_cuda_root(): ...@@ -338,26 +339,59 @@ def get_cuda_root():
return '' return ''
def default_dnn_include_path(): AddConfigVar('dnn.base_path',
cuda_root = get_cuda_root() "Install location of cuDNN.",
if cuda_root == '': StrParam(get_cuda_root),
in_c_key=False)
def default_dnn_inc_path():
if theano.config.dnn.base_path != '':
return os.path.join(theano.config.dnn.base_path, 'include')
return '' return ''
return os.path.join(cuda_root, 'include')
AddConfigVar('dnn.include_path', AddConfigVar('dnn.include_path',
"Location of the cudnn header (defaults to the cuda root)", "Location of the cudnn header",
# We keep the default here since cudnn needs headers from cuda StrParam(default_dnn_inc_path),
StrParam(default_dnn_include_path()),
# Added elsewhere in the c key only when needed.
in_c_key=False) in_c_key=False)
def default_dnn_lib_path():
if theano.config.dnn.base_path != '':
path = os.path.join(theano.config.dnn.base_path, 'lib')
if sys.platform == 'win32':
path = os.path.join(path, 'x64')
return path
return ''
AddConfigVar('dnn.library_path', AddConfigVar('dnn.library_path',
"Location of the cudnn library (defaults to the cuda root)", "Location of the cudnn link library.",
StrParam(''), StrParam(default_dnn_lib_path),
# Added elsewhere in the c key only when needed.
in_c_key=False) in_c_key=False)
def default_dnn_bin_path():
if type(theano.config.dnn).base_path.is_default:
return ''
else:
if theano.config.dnn.base_path != '':
if sys.platform == 'win32':
return os.path.join(theano.config.dnn.base_path, 'bin')
else:
return theano.config.dnn.library_path
return ''
AddConfigVar('dnn.bin_path',
"Location of the cuDNN load library "
"(on non-windows platforms, "
"this is the same as dnn.library_path)",
StrParam(default_dnn_bin_path),
in_c_key=False)
AddConfigVar('dnn.enabled', AddConfigVar('dnn.enabled',
"'auto', use cuDNN if available, but silently fall back" "'auto', use cuDNN if available, but silently fall back"
" to not using it if not present." " to not using it if not present."
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论