提交 89f46144 authored 作者: khaotik's avatar khaotik

call to cublasSdot in gemv when for row-vector matrix

上级 d1705c09
......@@ -2963,11 +2963,11 @@ CudaNdarray_select_a_gpu(PyObject* _unused, PyObject* dummy)
for (int device = 0; device < num_gpus; device++) {
cudaSetDevice(device);
err = cudaDeviceSynchronize(); // << CUDA context gets created here.
cudaGetLastError(); // reset the error state
cudaGetLastError(); // reset the error state
if (cudaSuccess == err)
break;
}
if (cudaSuccess != err){
printf("ERR!\\n");
PyErr_Format(PyExc_RuntimeError,
......@@ -4393,15 +4393,42 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
int sb_0 = (CudaNdarray_HOST_DIMS(B)[0] > 1) ? CudaNdarray_HOST_STRIDES(B)[0] : 1;
int sc_0 = (CudaNdarray_HOST_DIMS(C)[0] > 1) ? CudaNdarray_HOST_STRIDES(C)[0] : 1;
if (sa_0 == 0)
sa_0 = 1;
if (sa_1 == 0)
sa_1 = 1;
if (sa_0 == 0) sa_0 = 1;
if (sa_1 == 0) sa_1 = 1;
// This is important because we can end up not calling Sgemv at all
cublasStatus_t err = CUBLAS_STATUS_SUCCESS;
if (CudaNdarray_SIZE(C)) {
if ((CudaNdarray_HOST_DIMS(A)[0] <= 1)
// A is row vector & alpha==1 & beta==0 -> use cublasSdot
if (CudaNdarray_HOST_DIMS(A)[0] == 1 && alpha==1.f && beta==0.f) {
//TODO: this is a rather temporary solution
// 1. better temp solution:
// replace this with custom inner product kernel with
// alpha and beta parameter
// 2. permanant solution:
// define a new "InnerProduct" Op, add an optimization
// "gemv -> inner_prod", perhaps for CPU/GPU both
float* dev_dst = CudaNdarray_DEV_DATA(C)+1-sc_0;
cublasPointerMode_t pmode;
cublasGetPointerMode(handle, &pmode);
// need to store dot result on device here
cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE);
err = cublasSdot(
handle, CudaNdarray_HOST_DIMS(A)[1],
CudaNdarray_DEV_DATA(A), sa_1,
CudaNdarray_DEV_DATA(B), sb_0,
dev_dst);
cublasSetPointerMode(handle, pmode);
if (CUBLAS_STATUS_SUCCESS != err)
{
PyErr_Format(PyExc_RuntimeError,
"cublasSdot failed (%i)",
err);
return -1;
}
}
// A is row-contiguous | row vector
else if ((CudaNdarray_HOST_DIMS(A)[0] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[0] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[1] > 0)))
{
......@@ -4413,6 +4440,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
&beta,
CudaNdarray_DEV_DATA(C), sc_0);
}
// A is column-contiguous | column vector
else if ((CudaNdarray_HOST_DIMS(A)[1] <= 1)
|| ((CudaNdarray_HOST_STRIDES(A)[1] == 1)
&& (CudaNdarray_HOST_STRIDES(A)[0] > 0)))
......@@ -4425,6 +4453,7 @@ int CudaNdarray_sgemv(float alpha, const CudaNdarray * A, const CudaNdarray * B,
&beta,
CudaNdarray_DEV_DATA(C), sc_0);
}
// A is non vector and have malformed strides
else
{
PyErr_Format(PyExc_AssertionError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论