提交 0fd05518 作者: Ciyong Chen

update codes based on the comments

上级 1b9499e6
...@@ -12,8 +12,6 @@ from theano.tensor.nnet.abstract_conv import get_conv_output_shape ...@@ -12,8 +12,6 @@ from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas import ldflags, blas_header_version from theano.tensor.blas import ldflags, blas_header_version
from multiprocessing import cpu_count
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -122,13 +120,13 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -122,13 +120,13 @@ class BaseCorrMM(gof.OpenMPOp):
sub['n_bytes'] = 8 sub['n_bytes'] = 8
sub['c_float_type'] = 'double' sub['c_float_type'] = 'double'
if self.openmp: if self.openmp:
sub['cores'] = self.cores sub['omp_flags'] = '#pragma omp parallel for schedule(static)'
sub['omp_flags'] = '#pragma omp parallel for' sub['omp_max_threads'] = 'omp_get_max_threads()'
sub['omp_set_threads'] = 'omp_set_num_threads' sub['omp_set_threads'] = 'omp_set_num_threads'
sub['omp_get_threads'] = 'omp_get_thread_num()' sub['omp_get_threads'] = 'omp_get_thread_num()'
else: else:
sub['cores'] = 1
sub['omp_flags'] = '' sub['omp_flags'] = ''
sub['omp_max_threads'] = 1
sub['omp_set_threads'] = '' sub['omp_set_threads'] = ''
sub['omp_get_threads'] = 0 sub['omp_get_threads'] = 0
...@@ -342,7 +340,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -342,7 +340,7 @@ class BaseCorrMM(gof.OpenMPOp):
else { else {
typenum = PyArray_TYPE(bottom); typenum = PyArray_TYPE(bottom);
} }
%(out)s = (PyArrayObject*)PyArray_ZEROS(4, %(out)s = (PyArrayObject*)PyArray_EMPTY(4,
out_dim, out_dim,
typenum, typenum,
0); 0);
......
...@@ -184,10 +184,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -184,10 +184,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
} }
// Create temporary columns // Create temporary columns
int max_threads = %(omp_get_max_threads)s; const int max_threads = %(omp_max_threads)s < batchSize ? %(omp_max_threads)s : batchSize;
if (batchSize < max_threads) {
max_threads = batchSize;
}
npy_intp col_dim[3]; npy_intp col_dim[3];
col_dim[0] = (npy_intp)max_threads; col_dim[0] = (npy_intp)max_threads;
col_dim[1] = (npy_intp)(nChannels * kW * kH); col_dim[1] = (npy_intp)(nChannels * kW * kH);
...@@ -289,7 +287,6 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -289,7 +287,6 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
weight_dim[0], weight_dim[1]); weight_dim[0], weight_dim[1]);
return NULL; return NULL;
} }
local_weight = PyArray_GETCONTIGUOUS(local_weight);
// valid convolution: im2col, then gemm // valid convolution: im2col, then gemm
// Iterate over batch // Iterate over batch
...@@ -328,6 +325,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -328,6 +325,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
i * weight_dim[1] + j); i * weight_dim[1] + j);
} }
} }
Py_DECREF(local_weight);
/* /*
// Original caffe code for comparison // Original caffe code for comparison
// Note that this code was translated from the Theano GPU code, // Note that this code was translated from the Theano GPU code,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论