提交 d8543295 authored 作者: Frederic's avatar Frederic

Fix copy to/from the gpu of size bigger then 2g

上级 6686a051
...@@ -625,17 +625,20 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args) ...@@ -625,17 +625,20 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args)
npy_intp rval_size = PyArray_SIZE(rval); npy_intp rval_size = PyArray_SIZE(rval);
void *rval_data = PyArray_DATA(rval); void *rval_data = PyArray_DATA(rval);
cublasStatus_t err; cudaError_t err;
CNDA_BEGIN_ALLOW_THREADS CNDA_BEGIN_ALLOW_THREADS;
err = cublasGetVector(rval_size, sizeof(real),
contiguous_self->devdata, 1,
rval_data, 1);
//CNDA_THREAD_SYNC; // unneeded because cublasGetVector is blocking anyway
CNDA_END_ALLOW_THREADS
if (CUBLAS_STATUS_SUCCESS != err) err = cudaMemcpy(rval_data, contiguous_self->devdata,
rval_size * sizeof(real),
cudaMemcpyDeviceToHost
);
//CNDA_THREAD_SYNC; // unneeded because cudaMemcpy is blocking anyway
CNDA_END_ALLOW_THREADS;
if (cudaSuccess != err)
{ {
PyErr_SetString(PyExc_RuntimeError, "error copying data to host"); PyErr_Format(PyExc_RuntimeError, "error (%s)copying data to host",
cudaGetErrorString(err));
Py_DECREF(rval); Py_DECREF(rval);
rval = NULL; rval = NULL;
} }
...@@ -3754,20 +3757,19 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj) ...@@ -3754,20 +3757,19 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
} }
npy_intp py_src_size = PyArray_SIZE(py_src); npy_intp py_src_size = PyArray_SIZE(py_src);
void *py_src_data = PyArray_DATA(py_src); void *py_src_data = PyArray_DATA(py_src);
cublasStatus_t cerr; cudaError_t cerr;
CNDA_BEGIN_ALLOW_THREADS CNDA_BEGIN_ALLOW_THREADS;
cerr = cublasSetVector(py_src_size, cerr = cudaMemcpy(self->devdata, py_src_data,
sizeof(real), py_src_size * sizeof(real),
py_src_data, 1, cudaMemcpyHostToDevice);
self->devdata, 1); //CNDA_THREAD_SYNC; // unneeded because cudaMemcpy is blocking anyway
//CNDA_THREAD_SYNC; // unneeded because cublasSetVector is blocking anyway CNDA_END_ALLOW_THREADS;
CNDA_END_ALLOW_THREADS if (cudaSuccess != cerr)
if (CUBLAS_STATUS_SUCCESS != cerr)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"CUBLAS error '%s' while copying %lli data element" "Cuda error '%s' while copying %lli data element"
" to device memory", " to device memory",
cublasGetErrorString(cerr), cudaGetErrorString(cerr),
(long long)py_src_size); (long long)py_src_size);
Py_DECREF(py_src); Py_DECREF(py_src);
return -1; return -1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论