提交 b834babe authored 作者: Ciyong Chen's avatar Ciyong Chen

forcing gemm to use 1 thread in corrOP, and change PyArray_EMPTY to PyArray_ZEROS

上级 192a59d1
...@@ -62,6 +62,16 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -62,6 +62,16 @@ class BaseCorrMM(gof.OpenMPOp):
self.subsample = tuple(subsample) self.subsample = tuple(subsample)
self.filter_dilation = tuple(filter_dilation) self.filter_dilation = tuple(filter_dilation)
if not theano.config.blas.ldflags:
raise NotImplementedError("C code for corrMM* classes need a blas library.")
else:
if 'openblas' in theano.config.blas.ldflags:
self.blas_type = 'openblas'
elif 'mkl' in theano.config.blas.ldflags:
self.blas_type = 'mkl'
else:
self.blas_type = ''
@property @property
def pad(self): def pad(self):
if self.border_mode != 'valid': if self.border_mode != 'valid':
...@@ -95,6 +105,10 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -95,6 +105,10 @@ class BaseCorrMM(gof.OpenMPOp):
def c_headers(self): def c_headers(self):
headers = ['<stdio.h>'] headers = ['<stdio.h>']
headers += super(BaseCorrMM, self).c_headers() headers += super(BaseCorrMM, self).c_headers()
if self.blas_type == 'openblas':
headers += ['cblas.h']
if self.blas_type == 'mkl':
headers += ['mkl.h']
return headers return headers
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -119,6 +133,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -119,6 +133,7 @@ class BaseCorrMM(gof.OpenMPOp):
sub['float_typenum'] = 'NPY_DOUBLE' sub['float_typenum'] = 'NPY_DOUBLE'
sub['n_bytes'] = 8 sub['n_bytes'] = 8
sub['c_float_type'] = 'double' sub['c_float_type'] = 'double'
if self.openmp: if self.openmp:
sub['omp_flags'] = '#pragma omp parallel for schedule(static)' sub['omp_flags'] = '#pragma omp parallel for schedule(static)'
sub['omp_max_threads'] = 'omp_get_max_threads()' sub['omp_max_threads'] = 'omp_get_max_threads()'
...@@ -130,6 +145,13 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -130,6 +145,13 @@ class BaseCorrMM(gof.OpenMPOp):
sub['omp_set_threads'] = '' sub['omp_set_threads'] = ''
sub['omp_get_threads'] = 0 sub['omp_get_threads'] = 0
if self.blas_type == 'openblas':
sub['blas_flags'] = 'openblas_set_num_threads(1)'
elif self.blas_type == 'mkl':
sub['blas_flags'] = 'mkl_set_num_threads(1)'
else:
sub['blas_flags'] = ''
files = ['corr_gemm.c'] files = ['corr_gemm.c']
codes = [open(os.path.join(os.path.split(__file__)[0], f)).read() codes = [open(os.path.join(os.path.split(__file__)[0], f)).read()
for f in files] for f in files]
...@@ -173,8 +195,6 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -173,8 +195,6 @@ class BaseCorrMM(gof.OpenMPOp):
If self.border_mode == 'half', a variable giving the width of the If self.border_mode == 'half', a variable giving the width of the
filters for direction="backprop weights". Ignored otherwise. filters for direction="backprop weights". Ignored otherwise.
""" """
if not theano.config.blas.ldflags:
raise NotImplementedError("C code for CorrMM* classes need a blas library.")
dH, dW = self.subsample dH, dW = self.subsample
dilH, dilW = self.filter_dilation dilH, dilW = self.filter_dilation
if self.border_mode == "half": if self.border_mode == "half":
...@@ -340,7 +360,8 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -340,7 +360,8 @@ class BaseCorrMM(gof.OpenMPOp):
else { else {
typenum = PyArray_TYPE(bottom); typenum = PyArray_TYPE(bottom);
} }
%(out)s = (PyArrayObject*)PyArray_EMPTY(4, //Change to PyArray_ZEROS which is faster than PyArray_EMPTY.
%(out)s = (PyArrayObject*)PyArray_ZEROS(4,
out_dim, out_dim,
typenum, typenum,
0); 0);
......
...@@ -189,7 +189,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -189,7 +189,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
col_dim[0] = (npy_intp)max_threads; col_dim[0] = (npy_intp)max_threads;
col_dim[1] = (npy_intp)(nChannels * kW * kH); col_dim[1] = (npy_intp)(nChannels * kW * kH);
col_dim[2] = (npy_intp)(topHeight * topWidth); col_dim[2] = (npy_intp)(topHeight * topWidth);
//Change to PyArray_ZEROS which is faster than PyArray_EMPTY. //Change to PyArray_ZEROS which is faster than PyArray_EMPTY.
PyArrayObject* col = (PyArrayObject*)PyArray_ZEROS(3, PyArrayObject* col = (PyArrayObject*)PyArray_ZEROS(3,
col_dim, col_dim,
...@@ -230,6 +230,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -230,6 +230,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm // Second, gemm
// Always forcing gemm to one thread here for best and stable performance.
%(blas_flags)s;
%(gemm)s(&NTrans, &NTrans, %(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_, &N_, &M_, &K_,
&one, &one,
...@@ -301,6 +303,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -301,6 +303,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
// Always forcing gemm to one thread here for best and stable performance.
%(blas_flags)s;
%(gemm)s(&Trans, &NTrans, %(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_, &K_, &M_, &N_,
&one, &one,
...@@ -365,6 +369,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -365,6 +369,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
// gemm into columns // gemm into columns
int tid = %(omp_get_threads)s; int tid = %(omp_get_threads)s;
// Always forcing gemm to one thread here for best and stable performance.
%(blas_flags)s;
%(gemm)s(&NTrans, &Trans, %(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_, &N_, &K_, &M_,
&one, &one,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论