提交 e76a29d9 authored 作者: f0k's avatar f0k

Adds caffe's implementation of the full convolution to CorrMM; cleaning up and…

Adds caffe's implementation of the full convolution to CorrMM; cleaning up and documenting the code on the way
上级 f6bf2943
......@@ -501,29 +501,59 @@ gpu_ger_inplace = GpuGer(inplace=True)
class GpuCorrMM(GpuOp):
"""GPU correlation implementation using Matrix Multiply.
"""GPU correlation/convolution implementation using Matrix Multiplication.
:note: It don't implement the grad. So you should use it by
enabling the Theano flag ``optimizer_including=conv_gemm`` and
use :func:`conv2d <theano.tensor.nnet.conv.conv2d>`.
:note: It doesn't implement the grad. So you shouldn't use it directly, but
use :func:`conv2d <theano.tensor.nnet.conv.conv2d>` and then enable the
Theano flag ``optimizer_including=conv_gemm`` to automatically replace
all convolution operations with `GpuCorrMM`.
"""
def __init__(self, border_mode,
subsample=(1, 1),
pad=0):
pad=(0, 0)):
"""
:param border_mode: "valid" or "full"
:param subsample: the subsample operation applied on each output image.
:param subsample: the subsample operation applied to each output image.
Should be a tuple with 2 elements.
(sv, sh) is equivalent to GpuCorrMM(...)(...)[:,:,::sv, ::sh]
:param pad: not yet supported
If border_mode="full", this is instead treated as an upsampling
operation applied to each input image.
Set to (1, 1) to disable downsampling/upsampling.
:param pad: the width of a border of implicit zeros to pad the input
image with. Should be a tuple with 2 elements giving the numbers of
rows and columns to pad on each side, or "auto" to set the padding
to (kernel_rows - 1, kernel_columns - 1) at runtime.
If border_mode="full", this is instead treated as the width of a
border to crop from the output image.
Set to (0, 0) to disable padding/cropping.
:note: The border_mode changes the meaning of several parameters.
If border_mode="valid", the Op does a valid correlation of a padded
input image and subsamples it. (To perform a convolution instead,
you will need to flip the kernels.)
If border_mode="full", the Op does a full convolution of an
upsampled input image and crops it. (This can be used as a backward
pass of the valid correlation done with border_mode="valid".)
Combined with pad="auto", you can use border_mode="valid" to
simulate a full correlation with subsampling, or border_mode="full"
to simulate a valid convolution with upsampling.
:note: Currently, the Op requires a very specific memory layout.
For border_mode="valid", inputs, filters and outputs must be
C-contiguous. For border_mode="full", the same applies, except that
the strides of the first two dimensions of the filters (output and
input channels) must be swapped compared to C-contiguity.
"""
self.border_mode = border_mode
self.subsample = subsample
#if (border_mode == "full") and (subsample != (1,1)):
# raise NotImplementedError(
# "GpuCorrMM doesn't support subsampling for border_mode='full'")
self.pad = pad
if pad != 0:
raise NotImplementedError(
"GpuCorrMM don't implement the pad parameter")
#if (border_mode == "full") and (pad != (0,0)):
# raise NotImplementedError(
# "GpuCorrMM doesn't support padding for border_mode='full'")
def __eq__(self, other):
return type(self) == type(other) \
......@@ -540,7 +570,7 @@ class GpuCorrMM(GpuOp):
^ hash(self.pad)
def __str__(self):
return '%s{%s, %s, pad=%d}' % (
return '%s{%s, %s, pad=%r}' % (
self.__class__.__name__,
self.border_mode,
str(self.subsample),
......@@ -581,7 +611,7 @@ class GpuCorrMM(GpuOp):
def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files
return (0, 22)
return (0, 23)
def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of
......@@ -596,13 +626,18 @@ class GpuCorrMM(GpuOp):
out, = out_
dx = self.subsample[0]
dy = self.subsample[1]
sub = sub.copy()
pad = self.pad
if self.pad == "auto":
padH = padW = -1
else:
padH = self.pad[0]
padW = self.pad[1]
if self.border_mode == "valid":
bmode = 1
else:
assert self.border_mode == "full"
elif self.border_mode == "full":
bmode = 0
else:
raise ValueError("mode must be one of 'full' or 'valid'")
sub = sub.copy()
sub.update(locals())
return """
......@@ -612,33 +647,34 @@ class GpuCorrMM(GpuOp):
//Optional args
int dx = %(dx)s;
int dy = %(dy)s;
int padH = 0;
int padW = 0;
int padH = %(padH)s;
int padW = %(padW)s;
CudaNdarray * img = %(img)s;
CudaNdarray * kern = %(kern)s;
CudaNdarray * out2 = NULL;
//TODO: Send self.pad, stride, etc
//Auto-padding if requested
if (padH < 0) {
padH = CudaNdarray_HOST_DIMS(kern)[2] - 1;
}
if (padW < 0) {
padW = CudaNdarray_HOST_DIMS(kern)[3] - 1;
}
int out_dim[4];
out_dim[0] = CudaNdarray_HOST_DIMS(img)[0];
out_dim[1] = CudaNdarray_HOST_DIMS(kern)[0];
int logical_rows, logical_cols;
if (mode == 1)
if (mode == 1) // valid correlation with padding and subsampling
{
logical_rows = CudaNdarray_HOST_DIMS(img)[2] - CudaNdarray_HOST_DIMS(kern)[2] + 1;
logical_cols = CudaNdarray_HOST_DIMS(img)[3] - CudaNdarray_HOST_DIMS(kern)[3] + 1;
out_dim[2] = ceil_intdiv(CudaNdarray_HOST_DIMS(img)[2] + 2*padH - CudaNdarray_HOST_DIMS(kern)[2] + 1, dx);
out_dim[3] = ceil_intdiv(CudaNdarray_HOST_DIMS(img)[3] + 2*padW - CudaNdarray_HOST_DIMS(kern)[3] + 1, dy);
}
else
else // full convolution with upsampling and cropping
{
logical_rows = CudaNdarray_HOST_DIMS(img)[2] + CudaNdarray_HOST_DIMS(kern)[2] - 1;
logical_cols = CudaNdarray_HOST_DIMS(img)[3] + CudaNdarray_HOST_DIMS(kern)[3] - 1;
padH = CudaNdarray_HOST_DIMS(kern)[2] - 1;
padW = CudaNdarray_HOST_DIMS(kern)[3] - 1;
out_dim[2] = (CudaNdarray_HOST_DIMS(img)[2] - 1) * dx + CudaNdarray_HOST_DIMS(kern)[2] - 2*padH;
out_dim[3] = (CudaNdarray_HOST_DIMS(img)[3] - 1) * dy + CudaNdarray_HOST_DIMS(kern)[3] - 2*padW;
}
out_dim[2] = ceil_intdiv(logical_rows, dx);
out_dim[3] = ceil_intdiv(logical_cols, dy);
if ( !(%(out)s
&& %(out)s->nd==4
......@@ -650,10 +686,9 @@ class GpuCorrMM(GpuOp):
{
Py_XDECREF(%(out)s);
%(out)s = (CudaNdarray*)CudaNdarray_NewDims(4,out_dim);
}
out2 = corrMM(%(img)s, %(kern)s, %(out)s, dx, dy, padH, padW);
out2 = corrMM(%(img)s, %(kern)s, %(out)s, mode, dx, dy, padH, padW);
if (out2==NULL){
%(fail)s
}
......
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_
#include <cublas_v2.h>
#include <cuda.h>
#include <driver_types.h> // cuda driver types
// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif
// CUDA: number of blocks for threads.
inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
#endif // CAFFE_COMMON_HPP_
// This uses a lot of code from Caffe (http://caffe.berkeleyvision.org/);
// sources are clearly marked. Below we reproduce the original license of
// the Caffe software.
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
......@@ -22,176 +25,351 @@ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
// Reference code: https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
#undef _GLIBCXX_ATOMIC_BUILTINS
#include <Python.h>
#include "cuda_ndarray.cuh"
#include "caffe_common.hpp"
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/caffe_common.hpp)
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// Use 1024 threads per block, which requires cuda sm_2x or above
const int CUDA_NUM_THREADS = 1024;
// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
const int CUDA_NUM_THREADS = 1024;
#else
const int CUDA_NUM_THREADS = 512;
#endif
// CUDA: number of blocks for threads.
inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
// Kernel for fast unfold+copy
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu)
// Kernels for fast unfold + copy
__global__ void im2col_kernel(const int n, const float* data_im,
const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w, const int height_col, const int width_col,
float* data_col) {
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int height_col, const int width_col,
float* data_col) {
CUDA_KERNEL_LOOP(index, n) {
int w_out = index % width_col;
index /= width_col;
int h_out = index % height_col;
int channel_in = index / height_col;
int channel_out = channel_in * ksize_h * ksize_w;
int h_index = index / width_col;
int h_out = h_index % height_col;
int channel_in = h_index / height_col;
int channel_out = channel_in * kernel_h * kernel_w;
int h_in = h_out * stride_h - pad_h;
int w_in = w_out * stride_w - pad_w;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
data_im += (channel_in * height + h_in) * width + w_in;
for (int i = 0; i < ksize_h; ++i) {
for (int j = 0; j < ksize_w; ++j) {
float* data_col_ptr = data_col;
data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out;
const float* data_im_ptr = data_im;
data_im_ptr += (channel_in * height + h_in) * width + w_in;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
int h = h_in + i;
int w = w_in + j;
*data_col = (h >= 0 && w >= 0 && h < height && w < width) ?
data_im[i * width + j] : 0;
data_col += height_col * width_col;
*data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
data_im_ptr[i * width + j] : 0;
data_col_ptr += height_col * width_col;
}
}
}
}
void im2col(const float* data_im, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w, float* data_col) {
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
float* data_col) {
// We are going to launch channels * height_col * width_col kernels, each
// kernel responsible for copying a single-channel grid.
int height_col = (height + 2 * pad_h - ksize_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - ksize_w) / stride_w + 1;
int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
int num_kernels = channels * height_col * width_col;
// Launch
im2col_kernel <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>> (
num_kernels, data_im, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w,
height_col, width_col, data_col
);
im2col_kernel<<<GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h,
pad_w, stride_h, stride_w, height_col,
width_col, data_col);
}
__global__ void col2im_kernel(const int n, const float* data_col,
const int height, const int width, const int channels,
const int patch_h, const int patch_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int height_col, const int width_col,
float* data_im) {
CUDA_KERNEL_LOOP(index, n) {
float val = 0;
int w = index % width + pad_w;
int h = (index / width) % height + pad_h;
int c = index / (width * height);
// compute the start and end of the output
int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1;
int w_col_end = min(w / stride_w + 1, width_col);
int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1;
int h_col_end = min(h / stride_h + 1, height_col);
/*
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = c * patch_h * patch_w + (h - h_col * stride_h) * ksize
+ (w - w_col * stride_w);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
}
}
*/
// equivalent implementation
int offset =
(c * patch_h * patch_w + h * patch_w + w) * height_col * width_col;
int coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col;
int coeff_w_col = (1 - stride_w * height_col * width_col);
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
}
}
data_im[index] = val;
}
}
void col2im(const float* data_col, const int channels,
const int height, const int width, const int patch_h, const int patch_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, float* data_im) {
int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
int num_kernels = channels * height * width;
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im_kernel<<<GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels, data_col, height, width, channels, patch_h, patch_w,
pad_h, pad_w, stride_h, stride_w,
height_col, width_col, data_im);
}
// Author: Arjun Jain
// Theano op code
// Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter
// Reference code: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
// and https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
CudaNdarray* corrMM(const CudaNdarray *input,
CudaNdarray *weight,
CudaNdarray *output,
int dH = 1,
int dW = 1,
int padH = 0,
int padW = 0)
CudaNdarray *weight,
CudaNdarray *output,
int mode,
int dH = 1,
int dW = 1,
int padH = 0,
int padW = 0)
{
cublasStatus_t status;
if (input->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "required input of 4D");
PyErr_SetString(PyExc_ValueError, "GpuCorrMM requires input of 4D");
}
if (weight->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "required weight of 4D");
PyErr_SetString(PyExc_ValueError, "GpuCorrMM requires weight of 4D");
}
if (output->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "GpuCorrMM requires output of 4D");
}
// Extract some shape information for later and check shape consistency
// inputs: (batchSize, nInputPlane, inputHeight, inputWidth)
const int batchSize = CudaNdarray_HOST_DIMS(input)[0];
const int nInputPlane = CudaNdarray_HOST_DIMS(input)[1];
const int inputHeight = CudaNdarray_HOST_DIMS(input)[2];
const int inputWidth = CudaNdarray_HOST_DIMS(input)[3];
// filters: (nOutputPlane, nInputPlane, rows, columns)
const int nOutputPlane = CudaNdarray_HOST_DIMS(weight)[0];
const int kH = CudaNdarray_HOST_DIMS(weight)[2];
const int kW = CudaNdarray_HOST_DIMS(weight)[3];
if (nInputPlane != CudaNdarray_HOST_DIMS(weight)[1]) {
PyErr_SetString(PyExc_ValueError,
"GpuCorrMM images and kernel must have the same stack size\n");
return NULL;
}
// outputs: (batchSize, nOutputPlane, outputHeight, outputWidth)
int outputHeight, outputWidth;
if (mode == 1) { // valid correlation with padding and subsampling
outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
}
else if (mode == 0) { // full convolution with upsampling and cropping
// these would be the shapes for a standard full convolution:
//outputHeight = (inputHeight + 2*padH + kH - 2) / dH + 1;
//outputWidth = (inputWidth + 2*padW + kW - 2) / dW + 1;
// but here, dH and dW are *upsampling* factors, and padding is reversed
// (because the implementation was meant as a backward pass for a CNN)
outputHeight = (inputHeight - 1) * dH + kH - 2*padH;
outputWidth = (inputWidth - 1) * dW + kW - 2*padW;
}
if (batchSize != CudaNdarray_HOST_DIMS(output)[0] ||
nOutputPlane != CudaNdarray_HOST_DIMS(output)[1] ||
outputHeight != CudaNdarray_HOST_DIMS(output)[2] ||
outputWidth != CudaNdarray_HOST_DIMS(output)[3]) {
PyErr_Format(PyExc_ValueError,
"GpuCorrMM output parameter has wrong shape %d %d %d %d, expected %d %d %d %d\n",
CudaNdarray_HOST_DIMS(output)[0], CudaNdarray_HOST_DIMS(output)[1],
CudaNdarray_HOST_DIMS(output)[2], CudaNdarray_HOST_DIMS(output)[3],
batchSize, nOutputPlane, outputHeight, outputWidth);
return NULL;
}
if (mode == 1) { // valid correlation: im2col, then gemm
// Create temporary columns (col_data)
int col_dim[2];
col_dim[0] = nInputPlane * kW * kH;
col_dim[1] = outputHeight * outputWidth;
CudaNdarray* col_data = (CudaNdarray*)CudaNdarray_NewDims(2, col_dim);
// Define some useful variables
const int ip_stride = CudaNdarray_HOST_STRIDES(input)[0];
const int op_stride = CudaNdarray_HOST_STRIDES(output)[0];
const int K_ = col_dim[0];
const int N_ = col_dim[1];
const int M_ = nOutputPlane;
const float alpha = 1.0f;
const float beta = 0.0f;
int kH = CudaNdarray_HOST_DIMS(weight)[2];
int kW = CudaNdarray_HOST_DIMS(weight)[3];
int nInputPlane = CudaNdarray_HOST_DIMS(input)[1];
// filters: (number of filters, nInputPlane, rows, columns)
int nOutputPlane = CudaNdarray_HOST_DIMS(weight)[0];
long batchSize = CudaNdarray_HOST_DIMS(input)[0];
if (CudaNdarray_HOST_DIMS(input)[1] != CudaNdarray_HOST_DIMS(weight)[1]){
PyErr_SetString(PyExc_ValueError,
"GpuCorrMM images and kernel must have the same stack size\n"
);
return NULL;
}
long inputHeight = CudaNdarray_HOST_DIMS(input)[2];
long inputWidth = CudaNdarray_HOST_DIMS(input)[3];
long outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
long outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
// check output, size (batchSize, nOutputPlane,
// outputHeight, outputWidth);
if (batchSize != CudaNdarray_HOST_DIMS(output)[0] ||
nOutputPlane != CudaNdarray_HOST_DIMS(output)[1] ||
outputHeight != CudaNdarray_HOST_DIMS(output)[2] ||
outputWidth != CudaNdarray_HOST_DIMS(output)[3]){
PyErr_Format(
PyExc_ValueError,
"GpuCorrMM outputs parameter don't have the good shape %d %d %d %d, %d %d %d %d\n",
batchSize, nOutputPlane, outputHeight, outputWidth,
CudaNdarray_HOST_DIMS(output)[0], CudaNdarray_HOST_DIMS(output)[1],
CudaNdarray_HOST_DIMS(output)[2], CudaNdarray_HOST_DIMS(output)[3]);
return NULL;
}
// Create temporary columns
int col_dim[2];
col_dim[0] = nInputPlane*kW*kH;
col_dim[1]= outputHeight*outputWidth;
CudaNdarray* columns = (CudaNdarray*)CudaNdarray_NewDims(2,col_dim);
int ip_stride = CudaNdarray_HOST_STRIDES(input)[0];
int op_stride = CudaNdarray_HOST_STRIDES(output)[0];
// For each elt in batch, do:
for (int elt = 0; elt < batchSize; elt ++) {
// Matrix mulitply per output:
// 1. Extract columns:
im2col(
input->devdata + elt*ip_stride,
nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
columns->devdata
);
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
float alpha = 1.0f; float beta = 0.0f;
int m = CudaNdarray_HOST_DIMS(columns)[1];
int n = CudaNdarray_HOST_DIMS(weight)[0];
int k = CudaNdarray_HOST_DIMS(columns)[0];
status = cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
m, n, k,
&alpha,
columns->devdata, m,
weight->devdata, k,
&beta,
output->devdata + elt * op_stride, m
);
if (status != CUBLAS_STATUS_SUCCESS) {
std::cerr << "!!!! CUBLAS error: ";
std::cerr << cublasGetErrorString(status) << "\n";
}
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
// First, im2col
im2col(input->devdata + n * ip_stride, nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, col_data->devdata);
// Second, gemm
cublasStatus_t status = cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
N_, M_, K_,
&alpha,
col_data->devdata, N_,
weight->devdata, K_,
&beta,
output->devdata + n * op_stride, N_);
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUBLAS error: %s\n",
cublasGetErrorString(status));
return NULL;
}
}
// Free temporary columns
Py_DECREF(col_data);
}
/*
// Original caffe code for comparison
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
// Note that this is for grouped convolution; we can ignore groups
const Dtype* bottom_data = bottom[i]->gpu_data();
Dtype* top_data = (*top)[i]->mutable_gpu_data();
Dtype* col_data = col_buffer_.mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
for (int n = 0; n < num_; ++n) {
// First, im2col
im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_,
width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
col_data);
// Second, innerproduct with groups
for (int g = 0; g < group_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
(Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
(Dtype)0., top_data + (*top)[i]->offset(n) + top_offset * g);
== (see https://github.com/BVLC/caffe/blob/master/src/caffe/util/math_functions.cu#L16)
cublasSgemm(CUBLAS_OP_N, CUBLAS_OP_N,
N_, M_, K_,
1.,
col_data + col_offset * g, N_,
weight + weight_offset * g, K_,
0.,
top_data + (*top)[i]->offset(n) + top_offset * g, N_);
}
}
*/
}
else if (mode == 0) { // full convolution: gemm, then col2im
// Create temporary columns (col_diff)
int col_dim[2];
col_dim[0] = nOutputPlane * kW * kH;
col_dim[1] = inputHeight * inputWidth;
CudaNdarray* col_diff = (CudaNdarray*)CudaNdarray_NewDims(2, col_dim);
// Define some useful variables
const int ip_stride = CudaNdarray_HOST_STRIDES(input)[0];
const int op_stride = CudaNdarray_HOST_STRIDES(output)[0];
const int K_ = col_dim[0];
const int N_ = col_dim[1];
const int M_ = nInputPlane;
const float alpha = 1.0f;
const float beta = 0.0f;
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
// gemm into columns
cublasStatus_t status = cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_T,
N_, K_, M_,
&alpha,
input->devdata + n * ip_stride, N_,
weight->devdata, K_,
&beta,
col_diff->devdata, N_);
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUBLAS error: %s\n",
cublasGetErrorString(status));
return NULL;
}
// col2im back to the data
col2im(col_diff->devdata, nOutputPlane, outputHeight, outputWidth,
kH, kW, padH, padW, dH, dW, output->devdata + n * op_stride);
}
// Free temporary columns
Py_DECREF(col_diff);
Py_DECREF(columns);
return output;
/*
// Original caffe code for comparison
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
// Note that this is the backward pass of a valid convolution, so
// top_diff is the input, bottom_diff is the output, weights are weights
Dtype* col_data = col_buffer_.mutable_gpu_data();
Dtype* col_diff = col_buffer_.mutable_gpu_diff();
Dtype* bottom_diff = (*bottom)[i]->mutable_gpu_diff();
for (int n = 0; n < num_; ++n) {
// gradient w.r.t. bottom data, if necessary
for (int g = 0; g < group_; ++g) {
caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
(Dtype)1., weight + weight_offset * g,
top_diff + top[i]->offset(n) + top_offset * g,
(Dtype)0., col_diff + col_offset * g);
== (see https://github.com/BVLC/caffe/blob/master/src/caffe/util/math_functions.cu#L16)
cublasSgemm(CUBLAS_OP_N, CUBLAS_OP_T, N_, K_, M_,
1.,
top_diff + top[i]->offset(n) + top_offset * g, N_,
weight + weight_offset * g, K_,
0.,
col_diff + col_offset * g, N_);
}
// col2im back to the data
col2im_gpu(col_diff, channels_, height_, width_,
kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
bottom_diff + (*bottom)[i]->offset(n));
}
*/
}
return output;
}
......@@ -1351,10 +1351,22 @@ def local_conv_gemm(node):
if (isinstance(node.op, GpuConv) and
node.op.border_mode in ['full', 'valid']):
img, kern = node.inputs
border_mode = node.op.border_mode
subsample = node.op.subsample
pad = (0,0)
if (border_mode == 'full') and ((subsample != (1,1)) or (pad != (0,0))):
# need to simulate this via a padded valid convolution
pad = 'auto'
border_mode = 'valid'
if (border_mode == 'valid'):
# need to flip the kernel for valid convolution
kern = gpu_contiguous(kern[:, :, ::-1, ::-1])
elif (border_mode == 'full'):
# need to bring kernel into correct memory layout for full convolution
kern = gpu_contiguous(kern.dimshuffle(1, 0, 2, 3)).dimshuffle(1, 0, 2, 3)
# need C-contiguous inputs
img = gpu_contiguous(img)
kern = kern[:, :, ::-1, ::-1]
kern = gpu_contiguous(kern)
return [GpuCorrMM(node.op.border_mode, node.op.subsample)(img, kern)]
return [GpuCorrMM(border_mode, subsample, pad)(img, kern)]
gpu_optimizer.register("conv_gemm", local_conv_gemm)
......
......@@ -848,7 +848,7 @@ def test_gemm_directly():
input: (batch size, channels, rows, columns)
filters: (number of filters, channels, rows, columns)
"""
for mode in ['full', 'valid']:
for mode in ['valid']: # 'full' currently disabled; doesn't allow subsampling
print 'Testing mode: ' + mode
for bs in range(1, 5):
for ch in range(1,4):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论