提交 8850e88d authored 作者: Ciyong Chen's avatar Ciyong Chen

add BLAS thread-control function declarations and restore the BLAS thread count after forcing it to 1

上级 b834babe
...@@ -961,6 +961,49 @@ def blas_header_text(): ...@@ -961,6 +961,49 @@ def blas_header_text():
return header return header
def mkl_threads_text():
    """Return C declarations for the MKL threading-control API.

    The returned header declares the ``MKL_*`` thread-control entry points
    (plus their lowercase ``#define`` aliases) so generated C code can save,
    force, and later restore the number of BLAS threads around gemm calls.
    """
    return """
    extern "C"
    {
        int MKL_Set_Num_Threads_Local(int);
        #define mkl_set_num_threads_local MKL_Set_Num_Threads_Local
        void MKL_Set_Num_Threads(int);
        #define mkl_set_num_threads MKL_Set_Num_Threads
        int MKL_Get_Max_Threads(void);
        #define mkl_get_max_threads MKL_Get_Max_Threads
        int MKL_Domain_Set_Num_Threads(int, int);
        #define mkl_domain_set_num_threads MKL_Domain_Set_Num_Threads
        int MKL_Domain_Get_Max_Threads(int);
        #define mkl_domain_get_max_threads MKL_Domain_Get_Max_Threads
        void MKL_Set_Dynamic(int);
        #define mkl_set_dynamic MKL_Set_Dynamic
        int MKL_Get_Dynamic(void);
        #define mkl_get_dynamic MKL_Get_Dynamic
    }
    """
def openblas_threads_text():
    """Return C declarations for the OpenBLAS threading-control API.

    The returned header declares the OpenBLAS functions used to query,
    force, and restore the number of BLAS threads around gemm calls.
    """
    return """
    extern "C"
    {
        void openblas_set_num_threads(int);
        void goto_set_num_threads(int);
        int openblas_get_num_threads(void);
    }
    """
def blas_header_version(): def blas_header_version():
# Version for the base header # Version for the base header
version = (1,) version = (1,)
......
...@@ -9,7 +9,7 @@ from theano import Apply ...@@ -9,7 +9,7 @@ from theano import Apply
from theano import gof from theano import gof
from theano.tensor import as_tensor_variable, TensorType from theano.tensor import as_tensor_variable, TensorType
from theano.tensor.nnet.abstract_conv import get_conv_output_shape from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.blas_headers import blas_header_text from theano.tensor import blas_headers
from theano.tensor.blas import ldflags, blas_header_version from theano.tensor.blas import ldflags, blas_header_version
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -86,7 +86,12 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -86,7 +86,12 @@ class BaseCorrMM(gof.OpenMPOp):
str(self.filter_dilation)) str(self.filter_dilation))
def c_support_code(self): def c_support_code(self):
return blas_header_text() ccodes = blas_headers.blas_header_text()
if self.blas_type == 'openblas':
ccodes += blas_headers.openblas_threads_text()
elif self.blas_type == 'mkl':
ccodes += blas_headers.mkl_threads_text()
return ccodes
def c_libraries(self): def c_libraries(self):
return ldflags() return ldflags()
...@@ -105,10 +110,6 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -105,10 +110,6 @@ class BaseCorrMM(gof.OpenMPOp):
def c_headers(self): def c_headers(self):
headers = ['<stdio.h>'] headers = ['<stdio.h>']
headers += super(BaseCorrMM, self).c_headers() headers += super(BaseCorrMM, self).c_headers()
if self.blas_type == 'openblas':
headers += ['cblas.h']
if self.blas_type == 'mkl':
headers += ['mkl.h']
return headers return headers
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -137,20 +138,22 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -137,20 +138,22 @@ class BaseCorrMM(gof.OpenMPOp):
if self.openmp: if self.openmp:
sub['omp_flags'] = '#pragma omp parallel for schedule(static)' sub['omp_flags'] = '#pragma omp parallel for schedule(static)'
sub['omp_max_threads'] = 'omp_get_max_threads()' sub['omp_max_threads'] = 'omp_get_max_threads()'
sub['omp_set_threads'] = 'omp_set_num_threads' sub['set_omp_threads'] = 'omp_set_num_threads'
sub['omp_get_threads'] = 'omp_get_thread_num()' sub['get_omp_threads'] = 'omp_get_thread_num()'
else:
sub['omp_flags'] = ''
sub['omp_max_threads'] = 1
sub['omp_set_threads'] = ''
sub['omp_get_threads'] = 0
if self.blas_type == 'openblas': if self.blas_type == 'openblas':
sub['blas_flags'] = 'openblas_set_num_threads(1)' sub['set_blas_threads'] = 'openblas_set_num_threads'
sub['get_blas_threads'] = 'openblas_get_num_threads()'
elif self.blas_type == 'mkl': elif self.blas_type == 'mkl':
sub['blas_flags'] = 'mkl_set_num_threads(1)' sub['set_blas_threads'] = 'mkl_set_num_threads'
sub['get_blas_threads'] = 'mkl_get_max_threads()'
else: else:
sub['blas_flags'] = '' sub['omp_flags'] = ''
sub['omp_max_threads'] = '1'
sub['set_omp_threads'] = ''
sub['get_omp_threads'] = '0'
sub['set_blas_threads'] = ''
sub['get_blas_threads'] = '0'
files = ['corr_gemm.c'] files = ['corr_gemm.c']
codes = [open(os.path.join(os.path.split(__file__)[0], f)).read() codes = [open(os.path.join(os.path.split(__file__)[0], f)).read()
......
...@@ -184,7 +184,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -184,7 +184,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
} }
// Create temporary columns // Create temporary columns
const int max_threads = %(omp_max_threads)s < batchSize ? %(omp_max_threads)s : batchSize; int max_threads = %(omp_max_threads)s;
if (batchSize < max_threads) {
max_threads = batchSize;
}
npy_intp col_dim[3]; npy_intp col_dim[3];
col_dim[0] = (npy_intp)max_threads; col_dim[0] = (npy_intp)max_threads;
col_dim[1] = (npy_intp)(nChannels * kW * kH); col_dim[1] = (npy_intp)(nChannels * kW * kH);
...@@ -216,22 +219,23 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -216,22 +219,23 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
char Trans = 'T'; char Trans = 'T';
PyArrayObject *output; PyArrayObject *output;
%(omp_set_threads)s(max_threads); %(set_omp_threads)s(max_threads);
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
output = top; output = top;
// valid correlation: im2col, then gemm // valid correlation: im2col, then gemm
// Iterate over batch // Iterate over batch
int blas_threads_saved = %(get_blas_threads)s;
// Always force gemm to a single thread when OpenMP is enabled, for best and stable performance.
%(set_blas_threads)s(1);
%(omp_flags)s %(omp_flags)s
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_threads)s; int tid = %(get_omp_threads)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm // Second, gemm
// Always forcing gemm to one thread here for best and stable performance.
%(blas_flags)s;
%(gemm)s(&NTrans, &NTrans, %(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_, &N_, &M_, &K_,
&one, &one,
...@@ -240,6 +244,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -240,6 +244,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
&zero, &zero,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_); (%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_);
} }
// Restore to previous blas threads
%(set_blas_threads)s(blas_threads_saved);
/* /*
// Original caffe code for comparison // Original caffe code for comparison
...@@ -291,10 +297,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -291,10 +297,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
// valid convolution: im2col, then gemm // valid convolution: im2col, then gemm
// Iterate over batch // Iterate over batch
int blas_threads_saved = %(get_blas_threads)s;
// Always force gemm to a single thread when OpenMP is enabled, for best and stable performance.
%(set_blas_threads)s(1);
// OMP for batch-level parallelization // OMP for batch-level parallelization
%(omp_flags)s %(omp_flags)s
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_threads)s; int tid = %(get_omp_threads)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
...@@ -303,8 +312,6 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -303,8 +312,6 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
// Always forcing gemm to one thread here for best and stable performance.
%(blas_flags)s;
%(gemm)s(&Trans, &NTrans, %(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_, &K_, &M_, &N_,
&one, &one,
...@@ -314,6 +321,9 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -314,6 +321,9 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
(%(float_type)s*)PyArray_DATA(local_weight) + (%(float_type)s*)PyArray_DATA(local_weight) +
tid * weight_dim[1], &K_); tid * weight_dim[1], &K_);
} }
// Restore to previous blas threads
%(set_blas_threads)s(blas_threads_saved);
//aggregate weights //aggregate weights
memset((%(float_type)s*)PyArray_DATA(weight), 0, M_ * K_*sizeof(%(float_type)s)); memset((%(float_type)s*)PyArray_DATA(weight), 0, M_ * K_*sizeof(%(float_type)s));
/* /*
...@@ -365,12 +375,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -365,12 +375,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
// full convolution: gemm, then col2im // full convolution: gemm, then col2im
// Iterate over batch // Iterate over batch
int blas_threads_saved = %(get_blas_threads)s;
// Always force gemm to a single thread when OpenMP is enabled, for best and stable performance.
%(set_blas_threads)s(1);
%(omp_flags)s %(omp_flags)s
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
// gemm into columns // gemm into columns
int tid = %(omp_get_threads)s; int tid = %(get_omp_threads)s;
// Always forcing gemm to one thread here for best and stable performance.
%(blas_flags)s;
%(gemm)s(&NTrans, &Trans, %(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_, &N_, &K_, &M_,
&one, &one,
...@@ -383,6 +394,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -383,6 +394,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
kH, kW, dilH, dilW, padH, padW, kH, kW, dilH, dilW, padH, padW,
dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride); dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride);
} }
// Restore to previous blas threads
%(set_blas_threads)s(blas_threads_saved);
/* /*
// Original caffe code for comparison // Original caffe code for comparison
// Note that this code was translated from the Theano GPU code, // Note that this code was translated from the Theano GPU code,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论