提交 cc3deb5e authored 作者: Ciyong Chen's avatar Ciyong Chen

rename omp and blas key name

上级 8850e88d
......@@ -137,23 +137,24 @@ class BaseCorrMM(gof.OpenMPOp):
if self.openmp:
sub['omp_flags'] = '#pragma omp parallel for schedule(static)'
sub['omp_max_threads'] = 'omp_get_max_threads()'
sub['set_omp_threads'] = 'omp_set_num_threads'
sub['get_omp_threads'] = 'omp_get_thread_num()'
sub['omp_get_max_threads'] = 'omp_get_max_threads()'
sub['omp_get_thread_num'] = 'omp_get_thread_num()'
if self.blas_type == 'openblas':
sub['set_blas_threads'] = 'openblas_set_num_threads'
sub['get_blas_threads'] = 'openblas_get_num_threads()'
sub['blas_set_num_threads'] = 'openblas_set_num_threads'
sub['blas_get_num_threads'] = 'openblas_get_num_threads()'
elif self.blas_type == 'mkl':
sub['set_blas_threads'] = 'mkl_set_num_threads'
sub['get_blas_threads'] = 'mkl_get_max_threads()'
sub['blas_set_num_threads'] = 'mkl_set_num_threads'
sub['blas_get_num_threads'] = 'mkl_get_max_threads()'
else:
sub['blas_set_num_threads'] = ''
sub['blas_get_num_threads'] = '0'
else:
sub['omp_flags'] = ''
sub['omp_max_threads'] = '1'
sub['set_omp_threads'] = ''
sub['get_omp_threads'] = '0'
sub['set_blas_threads'] = ''
sub['get_blas_threads'] = '0'
sub['omp_get_max_threads'] = '1'
sub['omp_get_thread_num'] = '0'
sub['blas_set_num_threads'] = ''
sub['blas_get_num_threads'] = '0'
files = ['corr_gemm.c']
codes = [open(os.path.join(os.path.split(__file__)[0], f)).read()
......
......@@ -184,7 +184,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
}
// Create temporary columns
int max_threads = %(omp_max_threads)s;
int max_threads = %(omp_get_max_threads)s;
if (batchSize < max_threads) {
max_threads = batchSize;
}
......@@ -219,18 +219,16 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
char Trans = 'T';
PyArrayObject *output;
%(set_omp_threads)s(max_threads);
if (direction == 0) { // forward pass
output = top;
// valid correlation: im2col, then gemm
// Iterate over batch
int blas_threads_saved = %(get_blas_threads)s;
int blas_threads_saved = %(blas_get_num_threads)s;
// Always forcing gemm to one thread when OpenMP is enalbed for best and stable performance.
%(set_blas_threads)s(1);
%(blas_set_num_threads)s(1);
%(omp_flags)s
for (int n = 0; n < batchSize; ++n) {
int tid = %(get_omp_threads)s;
int tid = %(omp_get_thread_num)s;
// First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
......@@ -245,7 +243,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_);
}
// Restore to previous blas threads
%(set_blas_threads)s(blas_threads_saved);
%(blas_set_num_threads)s(blas_threads_saved);
/*
// Original caffe code for comparison
......@@ -297,13 +295,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
// valid convolution: im2col, then gemm
// Iterate over batch
int blas_threads_saved = %(get_blas_threads)s;
int blas_threads_saved = %(blas_get_num_threads)s;
// Always forcing gemm to one thread when OpenMP is enalbed for best and stable performance.
%(set_blas_threads)s(1);
%(blas_set_num_threads)s(1);
// OMP for batch-level paralization
%(omp_flags)s
for (int n = 0; n < batchSize; ++n) {
int tid = %(get_omp_threads)s;
int tid = %(omp_get_thread_num)s;
// First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
......@@ -322,7 +320,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
tid * weight_dim[1], &K_);
}
// Restore to previous blas threads
%(set_blas_threads)s(blas_threads_saved);
%(blas_set_num_threads)s(blas_threads_saved);
//aggregate weights
memset((%(float_type)s*)PyArray_DATA(weight), 0, M_ * K_*sizeof(%(float_type)s));
......@@ -375,13 +373,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
// full convolution: gemm, then col2im
// Iterate over batch
int blas_threads_saved = %(get_blas_threads)s;
int blas_threads_saved = %(blas_get_num_threads)s;
// Always forcing gemm to one thread when OpenMP is enalbed for best and stable performance.
%(set_blas_threads)s(1);
%(blas_set_num_threads)s(1);
%(omp_flags)s
for (int n = 0; n < batchSize; ++n) {
// gemm into columns
int tid = %(get_omp_threads)s;
int tid = %(omp_get_thread_num)s;
%(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_,
&one,
......@@ -395,7 +393,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride);
}
// Restore to previous blas threads
%(set_blas_threads)s(blas_threads_saved);
%(blas_set_num_threads)s(blas_threads_saved);
/*
// Original caffe code for comparison
// Note that this code was translated from the Theano GPU code,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论