提交 2a228ccf authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Importing cuda without nvcc will now display a warning instead of raising an exception

上级 f35f6ba9
import os, sys
from theano.gof.compiledir import get_compiledir
from theano.compile import optdb
import theano.config as config
import logging, copy
_logger_name = 'theano_cuda_ndarray'
......@@ -15,8 +16,34 @@ def debug(*msg):
_logger.debug(_logger_name+'DEBUG: '+' '.join(str(m) for m in msg))
#compile type_support.cu
#this need that nvcc(part of cuda) is installed
# Compile type_support.cu
# This need that nvcc (part of cuda) is installed. If it is not, a warning is
# printed and this module will not be working properly (we set `enable_cuda`
# to False).
# This variable is True by default, and set to False if something goes wrong
# when trying to initialize cuda.
enable_cuda = True
# Global variable to avoid displaying the same warning multiple times.
cuda_warning_is_displayed = False
# Code factorized within a function so that it may be called from multiple
# places (which is not currently the case, but may be useful in the future).
def set_cuda_disabled():
"""Function used to disable cuda.
A warning is displayed, so that the user is aware that cuda-based code is
not going to work.
Note that there is no point calling this function from outside of
`cuda.__init__`, since it has no effect once the module is loaded.
"""
global enable_cuda, cuda_warning_is_displayed
enable_cuda = False
if not cuda_warning_is_displayed:
cuda_warning_is_displayed = True
warning('Cuda is disabled, cuda-based code will thus not be '
'working properly')
old_file = os.path.join(os.path.split(__file__)[0],'type_support.so')
if os.path.exists(old_file):
......@@ -30,6 +57,10 @@ except ImportError:
import nvcc_compiler
if not nvcc_compiler.is_nvcc_available():
set_cuda_disabled()
if enable_cuda:
print __file__
cuda_path=os.path.split(old_file)[0]
......@@ -64,21 +95,20 @@ except ImportError:
from type_support.type_support import *
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import (CudaNdarrayVariable,
if enable_cuda:
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.var import (CudaNdarrayVariable,
CudaNdarrayConstant,
CudaNdarraySharedVariable,
shared_constructor)
import basic_ops
from basic_ops import (GpuFromHost, HostFromGpu, GpuElemwise,
import basic_ops
from basic_ops import (GpuFromHost, HostFromGpu, GpuElemwise,
GpuDimShuffle, GpuSum, GpuReshape,
GpuSubtensor, GpuIncSubtensor, GpuFlatten, GpuShape)
import opt
import cuda_ndarray
import opt
import cuda_ndarray
import theano.config as config
def use(device=config.THEANO_GPU):
if use.device_number is None:
......
......@@ -19,6 +19,15 @@ def debug(*args):
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger.debug("DEBUG: "+' '.join(str(a) for a in args))
def is_nvcc_available():
"""Return True iff the nvcc compiler is found."""
try:
subprocess.call(['nvcc', '--version'], stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
return True
except:
return False
def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
preargs=[]):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论