提交 1e500bd6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Move the cuda and dnn flags in configdefaults since the later depends on the former.

Otherwise it is possible to get into a situation where we use the dnn flags, but the cuda ones aren't defined and stuff breaks.
上级 201b4610
......@@ -118,6 +118,104 @@ AddConfigVar(
in_c_key=False)
def default_cuda_root():
v = os.getenv('CUDA_ROOT', "")
if v:
return v
s = os.getenv("PATH")
if not s:
return ''
for dir in s.split(os.path.pathsep):
if os.path.exists(os.path.join(dir, "nvcc")):
return os.path.split(dir)[0]
return ''
AddConfigVar(
'cuda.root',
"""directory with bin/, lib/, include/ for cuda utilities.
This directory is included via -L and -rpath when linking
dynamically compiled modules. If AUTO and nvcc is in the
path, it will use one of nvcc parent directory. Otherwise
/usr/local/cuda will be used. Leave empty to prevent extra
linker directives. Default: environment variable "CUDA_ROOT"
or else "AUTO".
""",
StrParam(default_cuda_root),
in_c_key=False)
def filter_nvcc_flags(s):
assert isinstance(s, str)
flags = [flag for flag in s.split(' ') if flag]
if any([f for f in flags if not f.startswith("-")]):
raise ValueError(
"Theano nvcc.flags support only parameter/value pairs without"
" space between them. e.g.: '--machine 64' is not supported,"
" but '--machine=64' is supported. Please add the '=' symbol."
" nvcc.flags value is '%s'" % s)
return ' '.join(flags)
AddConfigVar('nvcc.flags',
"Extra compiler flags for nvcc",
ConfigParam("", filter_nvcc_flags),
# Not needed in c key as it is already added.
# We remove it as we don't make the md5 of config to change
# if theano.sandbox.cuda is loaded or not.
in_c_key=False)
AddConfigVar('nvcc.compiler_bindir',
"If defined, nvcc compiler driver will seek g++ and gcc"
" in this directory",
StrParam(""),
in_c_key=False)
AddConfigVar('nvcc.fastmath',
"",
BoolParam(False),
# Not needed in c key as it is already added.
# We remove it as we don't make the md5 of config to change
# if theano.sandbox.cuda is loaded or not.
in_c_key=False)
AddConfigVar('dnn.conv.workmem',
"This flag is deprecated; use dnn.conv.algo_fwd.",
EnumStr(''),
in_c_key=False)
AddConfigVar('dnn.conv.workmem_bwd',
"This flag is deprecated; use dnn.conv.algo_bwd.",
EnumStr(''),
in_c_key=False)
AddConfigVar('dnn.conv.algo_fwd',
"Default implementation to use for CuDNN forward convolution.",
EnumStr('small', 'none', 'large', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
in_c_key=False)
AddConfigVar('dnn.conv.algo_bwd',
"Default implementation to use for CuDNN backward convolution.",
EnumStr('none', 'deterministic', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
in_c_key=False)
def default_dnn_path(suffix):
def f(suffix=suffix):
if config.cuda.root == '':
return ''
return os.path.join(config.cuda.root, suffix)
return f
AddConfigVar('dnn.include_path',
"Location of the cudnn header (defaults to the cuda root)",
StrParam(default_dnn_path('include')))
AddConfigVar('dnn.library_path',
"Location of the cudnn header (defaults to the cuda root)",
StrParam(default_dnn_path('lib64')))
# This flag determines whether or not to raise error/warning message if
# there is a CPU Op in the computational graph.
AddConfigVar(
......
......@@ -27,8 +27,6 @@ from theano.sandbox.cuda import gpu_seqopt, register_opt
from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
import theano.sandbox.dnn_flags
def dnn_available():
if dnn_available.avail is None:
......
......@@ -8,6 +8,7 @@ import warnings
import numpy
from theano import config
from theano.compat import decode, decode_iter
from theano.gof import local_bitwidth
from theano.gof.utils import hash_from_file
......@@ -19,67 +20,6 @@ from theano.misc.windows import output_subprocess_Popen
_logger = logging.getLogger("theano.sandbox.cuda.nvcc_compiler")
from theano.configparser import (config, AddConfigVar, StrParam,
BoolParam, ConfigParam)
AddConfigVar('nvcc.compiler_bindir',
"If defined, nvcc compiler driver will seek g++ and gcc"
" in this directory",
StrParam(""),
in_c_key=False)
user_provided_cuda_root = True
def default_cuda_root():
global user_provided_cuda_root
v = os.getenv('CUDA_ROOT', "")
user_provided_cuda_root = False
if v:
return v
return find_cuda_root()
AddConfigVar('cuda.root',
"""directory with bin/, lib/, include/ for cuda utilities.
This directory is included via -L and -rpath when linking
dynamically compiled modules. If AUTO and nvcc is in the
path, it will use one of nvcc parent directory. Otherwise
/usr/local/cuda will be used. Leave empty to prevent extra
linker directives. Default: environment variable "CUDA_ROOT"
or else "AUTO".
""",
StrParam(default_cuda_root),
in_c_key=False)
def filter_nvcc_flags(s):
assert isinstance(s, str)
flags = [flag for flag in s.split(' ') if flag]
if any([f for f in flags if not f.startswith("-")]):
raise ValueError(
"Theano nvcc.flags support only parameter/value pairs without"
" space between them. e.g.: '--machine 64' is not supported,"
" but '--machine=64' is supported. Please add the '=' symbol."
" nvcc.flags value is '%s'" % s)
return ' '.join(flags)
AddConfigVar('nvcc.flags',
"Extra compiler flags for nvcc",
ConfigParam("", filter_nvcc_flags),
# Not needed in c key as it is already added.
# We remove it as we don't make the md5 of config to change
# if theano.sandbox.cuda is loaded or not.
in_c_key=False)
AddConfigVar('nvcc.fastmath',
"",
BoolParam(False),
# Not needed in c key as it is already added.
# We remove it as we don't make the md5 of config to change
# if theano.sandbox.cuda is loaded or not.
in_c_key=False)
nvcc_path = 'nvcc'
nvcc_version = None
......@@ -115,14 +55,6 @@ def is_nvcc_available():
return False
def find_cuda_root():
s = os.getenv("PATH")
if not s:
return
for dir in s.split(os.path.pathsep):
if os.path.exists(os.path.join(dir, "nvcc")):
return os.path.split(dir)[0]
rpath_defaults = []
......@@ -227,7 +159,7 @@ class NVCC_compiler(Compiler):
include_dirs
A list of include directory names (each gets prefixed with -I).
lib_dirs
A list of library search path directory names (each gets
A list of library search path directory names (each gets
prefixed with -L).
libs
A list of libraries to link with (each gets prefixed with -l).
......@@ -357,7 +289,8 @@ class NVCC_compiler(Compiler):
# provided an cuda.root flag, we need to add one, but
# otherwise, we don't add it. See gh-1540 and
# https://wiki.debian.org/RpathIssue for details.
if (user_provided_cuda_root and
if (not type(config.cuda).root.is_default and
os.path.exists(os.path.join(config.cuda.root, 'lib'))):
rpaths.append(os.path.join(config.cuda.root, 'lib'))
......
"""
This module contains the configuration flags for cudnn support.
Those are shared between the cuda and gpuarray backend which is why
they are in this file.
"""
import os.path
from theano.configparser import AddConfigVar, EnumStr, StrParam
from theano import config
AddConfigVar('dnn.conv.workmem',
"This flag is deprecated; use dnn.conv.algo_fwd.",
EnumStr(''),
in_c_key=False)
AddConfigVar('dnn.conv.workmem_bwd',
"This flag is deprecated; use dnn.conv.algo_bwd.",
EnumStr(''),
in_c_key=False)
AddConfigVar('dnn.conv.algo_fwd',
"Default implementation to use for CuDNN forward convolution.",
EnumStr('small', 'none', 'large', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
in_c_key=False)
AddConfigVar('dnn.conv.algo_bwd',
"Default implementation to use for CuDNN backward convolution.",
EnumStr('none', 'deterministic', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
in_c_key=False)
AddConfigVar('dnn.include_path',
"Location of the cudnn header (defaults to the cuda root)",
StrParam(lambda: os.path.join(config.cuda.root, 'include')))
AddConfigVar('dnn.library_path',
"Location of the cudnn header (defaults to the cuda root)",
StrParam(lambda: os.path.join(config.cuda.root, 'lib64')))
......@@ -28,8 +28,6 @@ from .nnet import GpuSoftmax
from .opt import gpu_seqopt, register_opt, conv_groupopt, op_lifter
from .opt_util import alpha_merge, output_merge
# We need to import this to define the flags.
from theano.sandbox import dnn_flags # noqa
def dnn_available():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论