提交 c89d9737 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5257 from khaotik/gpu_gemv_speedup

GPU gemv speedup
...@@ -2962,11 +2962,11 @@ CudaNdarray_select_a_gpu(PyObject* _unused, PyObject* dummy) ...@@ -2962,11 +2962,11 @@ CudaNdarray_select_a_gpu(PyObject* _unused, PyObject* dummy)
for (int device = 0; device < num_gpus; device++) { for (int device = 0; device < num_gpus; device++) {
cudaSetDevice(device); cudaSetDevice(device);
err = cudaDeviceSynchronize(); // << CUDA context gets created here. err = cudaDeviceSynchronize(); // << CUDA context gets created here.
cudaGetLastError(); // reset the error state cudaGetLastError(); // reset the error state
if (cudaSuccess == err) if (cudaSuccess == err)
break; break;
} }
if (cudaSuccess != err){ if (cudaSuccess != err){
printf("ERR!\\n"); printf("ERR!\\n");
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
...@@ -4392,15 +4392,31 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -4392,15 +4392,31 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1; int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1;
int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1; int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1;
if (sa_0 == 0) if (sa_0 == 0) sa_0 = 1;
sa_0 = 1; if (sa_1 == 0) sa_1 = 1;
if (sa_1 == 0)
sa_1 = 1; int used_dot = 0;
// This is important because we can end up not calling Sgemv at all // This is important because we can end up not calling Sgemv at all
cublasStatus_t err = CUBLAS_STATUS_SUCCESS; cublasStatus_t err = CUBLAS_STATUS_SUCCESS;
if (CudaNdarray_SIZE(C)) { if (CudaNdarray_SIZE(C)) {
if ((CudaNdarray_HOST_DIMS(A)[0] <= 1) // A is row vector & alpha==1 & beta==0 -> use cublasSdot
if (CudaNdarray_HOST_DIMS(A)[0] == 1 && alpha==1.f && beta==0.f) {
//replace this with custom inner product kernel with alpha and beta parameter?
cublasPointerMode_t pmode;
//set pointer mode to make sure cublas not storing on host pointer
cublasGetPointerMode(handle, &pmode);
cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
err = cublasSdot(
handle, CudaNdarray_HOST_DIMS(A)[1],
CudaNdarray_DEV_DATA(A), sa_1,
CudaNdarray_DEV_DATA(B), sb_0,
CudaNdarray_DEV_DATA(C));
cublasSetPointerMode(handle, pmode);
used_dot = 1;
}
// A is row-contiguous | row vector
else if ((CudaNdarray_HOST_DIMS(A)[0] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[0] == 1) || ((CudaNdarray_HOST_STRIDES(A)[0] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[1] > 0))) && (CudaNdarray_HOST_STRIDES(A)[1] > 0)))
{ {
...@@ -4412,6 +4428,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -4412,6 +4428,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
&beta, &beta,
CudaNdarray_DEV_DATA(C), sc_0); CudaNdarray_DEV_DATA(C), sc_0);
} }
// A is column-contiguous | column vector
else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1) else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[1] == 1) || ((CudaNdarray_HOST_STRIDES(A)[1] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[0] > 0))) && (CudaNdarray_HOST_STRIDES(A)[0] > 0)))
...@@ -4424,6 +4441,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -4424,6 +4441,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
&beta, &beta,
CudaNdarray_DEV_DATA(C), sc_0); CudaNdarray_DEV_DATA(C), sc_0);
} }
// A is non vector and have malformed strides
else else
{ {
PyErr_Format(PyExc_AssertionError, PyErr_Format(PyExc_AssertionError,
...@@ -4449,9 +4467,16 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -4449,9 +4467,16 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
if (CUBLAS_STATUS_SUCCESS != err) if (CUBLAS_STATUS_SUCCESS != err)
{ {
if (!used_dot)
{
PyErr_Format(PyExc_RuntimeError,
"cublasSgemv failed (%i)",
err);
} else {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"cublasSgemv failed (%i)", "cublasSdot failed (%i)",
err); err);
}
return -1; return -1;
} }
return 0; return 0;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论