提交 725f4480 authored 作者: f0k's avatar f0k

GpuCorrMM_gradWeights: Saved a CUDA call

上级 a7f65a46
......@@ -311,24 +311,12 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
else if (direction == 1) { // backprop wrt. weights
output = weight;
// valid convolution: im2col, then gemm
// Initialize target with zeros as we will accumulate into it
// (all kernels run on the null stream, so we don't need to synchronize)
cudaError_t err = cudaMemsetAsync(weight->devdata, 0,
sizeof(float) * M_ * K_);
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUDA error in cudaMemsetAsync: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
return NULL;
}
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
// First, im2col
im2col(bottom->devdata + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, padH, padW, dH, dW, col->devdata);
err = cudaGetLastError();
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUDA error in im2col: %s\n"
......@@ -338,13 +326,16 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
return NULL;
}
// Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.)
cublasStatus_t status = cublasSgemm(handle,
CUBLAS_OP_T, CUBLAS_OP_N,
K_, M_, N_,
&one,
col->devdata, N_,
top->devdata + n * top_stride, N_,
&one,
(n == 0) ? &zero : &one,
weight->devdata, K_);
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论