提交 725f4480 authored 作者: f0k's avatar f0k

GpuCorrMM_gradWeights: Saved a CUDA call

上级 a7f65a46
...@@ -311,24 +311,12 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -311,24 +311,12 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
else if (direction == 1) { // backprop wrt. weights else if (direction == 1) { // backprop wrt. weights
output = weight; output = weight;
// valid convolution: im2col, then gemm // valid convolution: im2col, then gemm
// Initialize target with zeros as we will accumulate into it
// (all kernels run on the null stream, so we don't need to synchronize)
cudaError_t err = cudaMemsetAsync(weight->devdata, 0,
sizeof(float) * M_ * K_);
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUDA error in cudaMemsetAsync: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
return NULL;
}
// Iterate over batch // Iterate over batch
for (int n = 0; n < batchSize; n++) { for (int n = 0; n < batchSize; n++) {
// First, im2col // First, im2col
im2col(bottom->devdata + n * bottom_stride, nChannels, bottomHeight, im2col(bottom->devdata + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, padH, padW, dH, dW, col->devdata); bottomWidth, kH, kW, padH, padW, dH, dW, col->devdata);
err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) { if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUDA error in im2col: %s\n" "GpuCorrMM encountered a CUDA error in im2col: %s\n"
...@@ -338,13 +326,16 @@ CudaNdarray* corrMM(CudaNdarray *const bottom, ...@@ -338,13 +326,16 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.)
cublasStatus_t status = cublasSgemm(handle, cublasStatus_t status = cublasSgemm(handle,
CUBLAS_OP_T, CUBLAS_OP_N, CUBLAS_OP_T, CUBLAS_OP_N,
K_, M_, N_, K_, M_, N_,
&one, &one,
col->devdata, N_, col->devdata, N_,
top->devdata + n * top_stride, N_, top->devdata + n * top_stride, N_,
&one, (n == 0) ? &zero : &one,
weight->devdata, K_); weight->devdata, K_);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论