提交 6d87ce28 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Reintroduce cuda.root along with cuda.include_path, for those ops that

directly link to cuda.
上级 18e8a813
......@@ -232,6 +232,41 @@ AddConfigVar('gpuarray.single_stream',
in_c_key=False)
def get_cuda_root():
# We look for the cuda path since we need headers from there
v = os.getenv('CUDA_ROOT', "")
if v:
return v
v = os.getenv('CUDA_PATH', "")
if v:
return v
s = os.getenv("PATH")
if not s:
return ''
for dir in s.split(os.path.pathsep):
if os.path.exists(os.path.join(dir, "nvcc")):
return os.path.dirname(os.path.abspath(dir))
return ''
AddConfigVar('cuda.root',
"Location of the cuda installation",
StrParam(get_cuda_root),
in_c_key=False)
def default_cuda_include():
if theano.config.cuda.root:
return os.path.join(theano.config.cuda.root, 'include')
return ''
AddConfigVar('cuda.include_path',
"Location of the cuda includes",
StrParam(default_cuda_include),
in_c_key=False)
def safe_no_dnn_workmem(workmem):
"""
Make sure the user is not attempting to use dnn.conv.workmem`.
......@@ -325,22 +360,6 @@ AddConfigVar('dnn.conv.precision',
in_c_key=False)
def get_cuda_root():
# We look for the cuda path since we need headers from there
v = os.getenv('CUDA_ROOT', "")
if v:
return v
s = os.getenv("PATH")
if not s:
if os.path.exists('/usr/local/cuda'):
return '/usr/local/cuda'
return ''
for dir in s.split(os.path.pathsep):
if os.path.exists(os.path.join(dir, "nvcc")):
return os.path.dirname(os.path.abspath(dir))
return ''
AddConfigVar('dnn.base_path',
"Install location of cuDNN.",
StrParam(''),
......@@ -359,19 +378,6 @@ AddConfigVar('dnn.include_path',
in_c_key=False)
def default_dnn_cuda_inc_path():
cuda_root = get_cuda_root()
if cuda_root:
return os.path.join(cuda_root, 'include')
return ''
AddConfigVar('dnn.cuda_include_path',
"Location of the cuda include for cudnn",
StrParam(default_dnn_cuda_inc_path),
in_c_key=False)
def default_dnn_lib_path():
if theano.config.dnn.base_path != '':
path = os.path.join(theano.config.dnn.base_path, 'lib')
......
......@@ -232,8 +232,8 @@ class DnnVersion(Op):
res = []
if config.dnn.include_path:
res.append(config.dnn.include_path)
if config.dnn.cuda_include_path:
res.append(config.dnn.cuda_include_path)
if config.cuda.include_path:
res.append(config.cuda.include_path)
return res
def c_libraries(self):
......@@ -375,8 +375,8 @@ class DnnBase(COp):
dirs = [os.path.dirname(__file__), pygpu.get_include()]
if config.dnn.include_path:
dirs.append(config.dnn.include_path)
if config.dnn.cuda_include_path:
dirs.append(config.dnn.cuda_include_path)
if config.cuda.include_path:
dirs.append(config.cuda.include_path)
return dirs
def c_libraries(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论