提交 d00f619d authored 作者: James Bergstra's avatar James Bergstra

added try-except block to cuda.__init__ for when compilation and loading of cuda_ndarray fails

上级 ae613167
...@@ -5,7 +5,7 @@ from theano import config ...@@ -5,7 +5,7 @@ from theano import config
import logging, copy import logging, copy
_logger_name = 'theano.sandbox.cuda' _logger_name = 'theano.sandbox.cuda'
_logger = logging.getLogger(_logger_name) _logger = logging.getLogger(_logger_name)
_logger.setLevel(logging.INFO) _logger.setLevel(logging.WARNING)
_logger.addHandler(logging.StreamHandler()) _logger.addHandler(logging.StreamHandler())
def warning(*msg): def warning(*msg):
_logger.warning(_logger_name+'WARNING: '+' '.join(str(m) for m in msg)) _logger.warning(_logger_name+'WARNING: '+' '.join(str(m) for m in msg))
...@@ -63,7 +63,8 @@ if not compile_cuda_ndarray: ...@@ -63,7 +63,8 @@ if not compile_cuda_ndarray:
except ImportError: except ImportError:
compile_cuda_ndarray = True compile_cuda_ndarray = True
if compile_cuda_ndarray: try:
if compile_cuda_ndarray:
import nvcc_compiler import nvcc_compiler
if not nvcc_compiler.is_nvcc_available(): if not nvcc_compiler.is_nvcc_available():
set_cuda_disabled() set_cuda_disabled()
...@@ -78,6 +79,9 @@ if compile_cuda_ndarray: ...@@ -78,6 +79,9 @@ if compile_cuda_ndarray:
include_dirs=[cuda_path], libs=['cublas']) include_dirs=[cuda_path], libs=['cublas'])
from cuda_ndarray.cuda_ndarray import * from cuda_ndarray.cuda_ndarray import *
except Exception, e:
_logger.error( "Failed to compile cuda_ndarray.cu: %s" % str(e))
set_cuda_disabled()
if enable_cuda: if enable_cuda:
#check if their is an old cuda_ndarray that was loading instead of the one we compiled! #check if their is an old cuda_ndarray that was loading instead of the one we compiled!
...@@ -139,6 +143,5 @@ def handle_shared_float32(tf): ...@@ -139,6 +143,5 @@ def handle_shared_float32(tf):
else: else:
raise NotImplementedError('removing our handler') raise NotImplementedError('removing our handler')
if enable_cuda and config.device.startswith('gpu'): if enable_cuda and config.device.startswith('gpu'):
use() use()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论