提交 ca0c64e0 作者: Ciyong Chen

corr_gemm optimization to improve CNN performance

上级 b9813e0b
......@@ -1873,7 +1873,8 @@ class GCC_compiler(Compiler):
if ('g++' not in theano.config.cxx and
'clang++' not in theano.config.cxx and
'clang-omp++' not in theano.config.cxx):
'clang-omp++' not in theano.config.cxx and
'icpc' not in theano.config.cxx):
_logger.warn(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization"
......
......@@ -10,13 +10,14 @@ from theano import gof
from theano.tensor import as_tensor_variable, TensorType
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas import ldflags
from theano.tensor.blas import ldflags, blas_header_version
from multiprocessing import cpu_count
_logger = logging.getLogger(__name__)
class BaseCorrMM(gof.Op):
class BaseCorrMM(gof.OpenMPOp):
"""
Base class for `CorrMM`, `CorrMM_gradWeights` and
`CorrMM_gradInputs`. Cannot be used directly.
......@@ -34,7 +35,8 @@ class BaseCorrMM(gof.Op):
__props__ = ('border_mode', 'subsample', 'filter_dilation')
def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1)):
filter_dilation=(1, 1), openmp=None):
super(BaseCorrMM, self).__init__(openmp=openmp)
if isinstance(border_mode, integer_types):
if border_mode < 0:
raise ValueError(
......@@ -82,7 +84,10 @@ class BaseCorrMM(gof.Op):
return ldflags()
def c_compile_args(self):
return ldflags(libs=False, flags=True)
compile_args = ldflags(libs=False, flags=True)
compile_args += super(BaseCorrMM, self).c_compile_args()
return compile_args
def c_lib_dirs(self):
return ldflags(libs=False, libs_dir=True)
......@@ -91,11 +96,13 @@ class BaseCorrMM(gof.Op):
return ldflags(libs=False, include_dir=True)
def c_headers(self):
return ['<stdio.h>']
headers = ['<stdio.h>']
headers += super(BaseCorrMM, self).c_headers()
return headers
def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files
return (1, 2)
return (1, self.openmp, blas_header_version())
def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of
......@@ -115,6 +122,17 @@ class BaseCorrMM(gof.Op):
sub['float_typenum'] = 'NPY_DOUBLE'
sub['n_bytes'] = 8
sub['c_float_type'] = 'double'
if self.openmp:
sub['cores'] = self.cores
sub['omp_flags'] = '#pragma omp parallel for'
sub['omp_set_threads'] = 'omp_set_num_threads'
sub['omp_get_threads'] = 'omp_get_thread_num()'
else:
sub['cores'] = 1
sub['omp_flags'] = ''
sub['omp_set_threads'] = ''
sub['omp_get_threads'] = 0
files = ['corr_gemm.c']
codes = [open(os.path.join(os.path.split(__file__)[0], f)).read()
for f in files]
......@@ -325,7 +343,7 @@ class BaseCorrMM(gof.Op):
else {
typenum = PyArray_TYPE(bottom);
}
%(out)s = (PyArrayObject*)PyArray_EMPTY(4,
%(out)s = (PyArrayObject*)PyArray_ZEROS(4,
out_dim,
typenum,
0);
......@@ -376,9 +394,6 @@ class CorrMM(BaseCorrMM):
Set to `(1, 1)` to disable filter dilation.
"""
def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1)):
super(CorrMM, self).__init__(border_mode, subsample, filter_dilation)
def make_node(self, img, kern):
img = as_tensor_variable(img)
......@@ -436,12 +451,6 @@ class CorrMM_gradWeights(BaseCorrMM):
"""
def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1)):
super(CorrMM_gradWeights, self).__init__(border_mode,
subsample,
filter_dilation)
def make_node(self, img, topgrad, shape=None):
img = as_tensor_variable(img)
topgrad = as_tensor_variable(topgrad)
......@@ -538,11 +547,6 @@ class CorrMM_gradInputs(BaseCorrMM):
"""
def __init__(self, border_mode="valid", subsample=(1, 1), filter_dilation=(1, 1)):
super(CorrMM_gradInputs, self).__init__(border_mode,
subsample,
filter_dilation)
def make_node(self, kern, topgrad, shape=None):
kern = as_tensor_variable(kern)
topgrad = as_tensor_variable(topgrad)
......
......@@ -26,7 +26,6 @@ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cpp)
// Loops for fast unfold + copy
void im2col(const %(float_type)s* data_im, const int channels,
......@@ -185,51 +184,64 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
}
// Create temporary columns
npy_intp col_dim[2];
col_dim[0] = (npy_intp)(nChannels * kW * kH);
col_dim[1] = (npy_intp)(topHeight * topWidth);
PyArrayObject* col = (PyArrayObject*)PyArray_EMPTY(2,
col_dim,
PyArray_TYPE(top),
0);
if (NULL == col)
{
int max_threads = %(omp_get_max_threads)s;
if (batchSize < max_threads) {
max_threads = batchSize;
}
npy_intp col_dim[3];
col_dim[0] = (npy_intp)max_threads;
col_dim[1] = (npy_intp)(nChannels * kW * kH);
col_dim[2] = (npy_intp)(topHeight * topWidth);
//Change to PyArray_ZEROS which is faster than PyArray_EMPTY.
PyArrayObject* col = (PyArrayObject*)PyArray_ZEROS(3,
col_dim,
PyArray_TYPE(top),
0);
if (NULL == col) {
PyErr_Format(PyExc_RuntimeError,
"CorrMM failed to allocate working memory of %%ld x %%ld\n",
col_dim[0], col_dim[1]);
"CorrMM failed to allocate working memory of"
" %%ld x %%ld x %%ld\n",
col_dim[0], col_dim[1], col_dim[2]);
return NULL;
}
// Define some useful variables
const int bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f;
const int top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f;
const int K_ = col_dim[0];
const int N_ = col_dim[1];
const int K_ = col_dim[1];
const int N_ = col_dim[2];
const int col_stride = (K_ * N_);
const int M_ = nFilters;
const %(c_float_type)s one = 1.0;
const %(c_float_type)s zero = 0.0;
char NTrans = 'N';
char Trans = 'T';
PyArrayObject *output;
%(omp_set_threads)s(max_threads);
if (direction == 0) { // forward pass
output = top;
// valid correlation: im2col, then gemm
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
%(omp_flags)s
for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_threads)s;
// First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW,
padH, padW, dH, dW, (%(float_type)s*)PyArray_DATA(col));
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm
%(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_,
&one,
(%(float_type)s*)PyArray_DATA(col), &N_,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_,
&zero,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_);
}
/*
// Original caffe code for comparison
// Note that this code was translated from the Theano GPU code,
......@@ -264,13 +276,31 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
}
else if (direction == 1) { // backprop wrt. weights
output = weight;
npy_intp weight_dim[2];
weight_dim[0] = (npy_intp)max_threads;
weight_dim[1] = (npy_intp)(M_ * K_);
PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2,
weight_dim, PyArray_TYPE(weight), 0);
if (NULL == local_weight)
{
PyErr_Format(PyExc_RuntimeError,
"CorrMM failed to allocate weight memory of %%ld x %%ld\n",
weight_dim[0], weight_dim[1]);
return NULL;
}
local_weight = PyArray_GETCONTIGUOUS(local_weight);
// valid convolution: im2col, then gemm
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
// OMP for batch-level parallelization
%(omp_flags)s
for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_threads)s;
// First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW,
padH, padW, dH, dW, (%(float_type)s*)PyArray_DATA(col));
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This
......@@ -278,10 +308,25 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
%(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_,
&one,
(%(float_type)s*)PyArray_DATA(col), &N_,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_,
(n == 0) ? &zero : &one,
(%(float_type)s*)PyArray_DATA(weight), &K_);
(%(float_type)s*)PyArray_DATA(local_weight) +
tid * weight_dim[1], &K_);
}
//aggregate weights
memset((%(float_type)s*)PyArray_DATA(weight), 0, M_ * K_*sizeof(%(float_type)s));
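/*
* Keep index "j" in the outer (parallelized) loop: each thread then
* owns distinct elements weight[j], so the accumulation over "i"
* stays race-free and gives the correct result when OpenMP is used.
*/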
%(omp_flags)s
for(int j = 0; j < weight_dim[1]; ++j){
for(int i = 0; i < max_threads; ++i){
((%(float_type)s*)PyArray_DATA(weight))[j] +=
*((%(float_type)s*)PyArray_DATA(local_weight) +
i * weight_dim[1] + j);
}
}
/*
// Original caffe code for comparison
......@@ -318,17 +363,20 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
PyArray_FILLWBYTE(bottom, 0);
// full convolution: gemm, then col2im
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
%(omp_flags)s
for (int n = 0; n < batchSize; ++n) {
// gemm into columns
int tid = %(omp_get_threads)s;
%(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_,
&one,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_,
&zero,
(%(float_type)s*)PyArray_DATA(col), &N_);
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_);
// col2im back to the data
col2im((%(float_type)s*)PyArray_DATA(col), nChannels, bottomHeight, bottomWidth,
col2im((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, bottomHeight, bottomWidth,
kH, kW, dilH, dilW, padH, padW,
dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride);
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论