提交 ae47cc39 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2012 from abergeron/fix_cublas_init

Only create the handle if we specify a device to gpu_init().
...@@ -3043,7 +3043,7 @@ CudaNdarray_ptr_int_size(PyObject* _unused, PyObject* args) ...@@ -3043,7 +3043,7 @@ CudaNdarray_ptr_int_size(PyObject* _unused, PyObject* args)
} }
static int cublas_init(); static int cublas_init();
static int cublas_shutdown(); static void cublas_shutdown();
// Initialize the gpu. // Initialize the gpu.
// Takes one optional parameter, the device number. // Takes one optional parameter, the device number.
// If provided, it sets that device to be the active device. // If provided, it sets that device to be the active device.
...@@ -3100,11 +3100,6 @@ CudaNdarray_gpu_init(PyObject* _unused, PyObject* args) ...@@ -3100,11 +3100,6 @@ CudaNdarray_gpu_init(PyObject* _unused, PyObject* args)
"There is no device that supports CUDA"); "There is no device that supports CUDA");
} }
// Initialize cublas
if (handle != NULL)
if (cublas_shutdown() == -1)
return NULL;
if(card_number_provided) { if(card_number_provided) {
err = cudaSetDevice(card_nb); err = cudaSetDevice(card_nb);
if(cudaSuccess != err) { if(cudaSuccess != err) {
...@@ -3113,11 +3108,10 @@ CudaNdarray_gpu_init(PyObject* _unused, PyObject* args) ...@@ -3113,11 +3108,10 @@ CudaNdarray_gpu_init(PyObject* _unused, PyObject* args)
card_nb, card_nb,
cudaGetErrorString(cudaGetLastError())); cudaGetErrorString(cudaGetLastError()));
} }
if (cublas_init() == -1)
return NULL;
} }
if (cublas_init() == -1)
return NULL;
Py_INCREF(Py_None); Py_INCREF(Py_None);
return Py_None; return Py_None;
} }
...@@ -3145,6 +3139,8 @@ CudaNdarray_active_device_name(PyObject* _unused, PyObject* _unused_args) { ...@@ -3145,6 +3139,8 @@ CudaNdarray_active_device_name(PyObject* _unused, PyObject* _unused_args) {
PyObject * PyObject *
CudaNdarray_gpu_shutdown(PyObject* _unused, PyObject* _unused_args) { CudaNdarray_gpu_shutdown(PyObject* _unused, PyObject* _unused_args) {
// Don't handle errors here
cublas_shutdown();
cudaThreadExit(); cudaThreadExit();
g_gpu_context_active = 0; // context has now been closed down g_gpu_context_active = 0; // context has now been closed down
Py_INCREF(Py_None); Py_INCREF(Py_None);
...@@ -3595,20 +3591,13 @@ cublas_init() ...@@ -3595,20 +3591,13 @@ cublas_init()
return 0; return 0;
} }
static int static void
cublas_shutdown() cublas_shutdown()
{ {
cublasStatus_t err; if (handle != NULL)
err = cublasDestroy(handle); cublasDestroy(handle);
if (CUBLAS_STATUS_SUCCESS != err) // No point in handling any errors here
{
PyErr_SetString(PyExc_RuntimeError,
"cublas_init tried to destroy the old cublas"
" context, cublasDestroy() returned an error.");
return -1;
}
handle = NULL; handle = NULL;
return 0;
} }
int int
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论