提交 79c1b8de 作者: affanv14 提交者: Mohammed Affan

modify corrmm to support num_groups

上级 69338f63
...@@ -51,10 +51,11 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -51,10 +51,11 @@ class BaseCorrMM(gof.OpenMPOp):
('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2 ('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2
dH=int64, dW=int64, dH=int64, dW=int64,
dilH=int64, dilW=int64, dilH=int64, dilW=int64,
padH=int64, padW=int64) padH=int64, padW=int64,
num_groups=int64)
def __init__(self, border_mode="valid", subsample=(1, 1), def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1), openmp=None): filter_dilation=(1, 1), num_groups=1, openmp=None):
super(BaseCorrMM, self).__init__(openmp=openmp) super(BaseCorrMM, self).__init__(openmp=openmp)
if isinstance(border_mode, integer_types): if isinstance(border_mode, integer_types):
if border_mode < 0: if border_mode < 0:
...@@ -97,6 +98,9 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -97,6 +98,9 @@ class BaseCorrMM(gof.OpenMPOp):
if self._direction not in ["forward", "backprop weights", "backprop inputs"]: if self._direction not in ["forward", "backprop weights", "backprop inputs"]:
raise ValueError("_direction must be one of 'forward', " raise ValueError("_direction must be one of 'forward', "
"'backprop weights', 'backprop inputs'") "'backprop weights', 'backprop inputs'")
if num_groups < 1:
raise ValueError("Number of groups should be greater than 0")
self.num_groups = num_groups
@property @property
def pad(self): def pad(self):
...@@ -274,6 +278,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -274,6 +278,7 @@ class BaseCorrMM(gof.OpenMPOp):
int dilW = %(params)s->dilW; int dilW = %(params)s->dilW;
int padH = %(params)s->padH; int padH = %(params)s->padH;
int padW = %(params)s->padW; int padW = %(params)s->padW;
int numgroups = %(params)s->num_groups;
PyArrayObject * bottom = %(bottom)s; PyArrayObject * bottom = %(bottom)s;
PyArrayObject * weights = %(weights)s; PyArrayObject * weights = %(weights)s;
...@@ -386,7 +391,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -386,7 +391,7 @@ class BaseCorrMM(gof.OpenMPOp):
// output is weights: (num_filters, num_channels, height, width) // output is weights: (num_filters, num_channels, height, width)
// height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1 // height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
out_dim[0] = (npy_intp)PyArray_DIMS(top)[1]; out_dim[0] = (npy_intp)PyArray_DIMS(top)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1]; out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1] / numgroups;
out_dim[2] = (npy_intp)kH; // already inferred further above out_dim[2] = (npy_intp)kH; // already inferred further above
out_dim[3] = (npy_intp)kW; // how convenient out_dim[3] = (npy_intp)kW; // how convenient
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0) if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
...@@ -409,7 +414,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -409,7 +414,7 @@ class BaseCorrMM(gof.OpenMPOp):
// output is bottom: (batchsize, num_channels, height, width) // output is bottom: (batchsize, num_channels, height, width)
// height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad // height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
out_dim[0] = (npy_intp)PyArray_DIMS(top)[0]; out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1]; out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1] * numgroups;
out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH); out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW); out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0) if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
...@@ -465,7 +470,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -465,7 +470,7 @@ class BaseCorrMM(gof.OpenMPOp):
} }
// Call corrMM code // Call corrMM code
out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW); out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW, numgroups );
if (out2==NULL){ if (out2==NULL){
%(fail)s %(fail)s
} }
...@@ -541,11 +546,13 @@ class CorrMM(BaseCorrMM): ...@@ -541,11 +546,13 @@ class CorrMM(BaseCorrMM):
top, = grads top, = grads
d_bottom = CorrMM_gradInputs(self.border_mode, d_bottom = CorrMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, top, self.filter_dilation,
self.num_groups)(weights, top,
bottom.shape[-2:]) bottom.shape[-2:])
d_weights = CorrMM_gradWeights(self.border_mode, d_weights = CorrMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, top, self.filter_dilation,
self.num_groups)(bottom, top,
weights.shape[-2:]) weights.shape[-2:])
return d_bottom, d_weights return d_bottom, d_weights
...@@ -632,11 +639,13 @@ class CorrMM_gradWeights(BaseCorrMM): ...@@ -632,11 +639,13 @@ class CorrMM_gradWeights(BaseCorrMM):
weights, = grads weights, = grads
d_bottom = CorrMM_gradInputs(self.border_mode, d_bottom = CorrMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, top, self.filter_dilation,
self.num_groups)(weights, top,
bottom.shape[-2:]) bottom.shape[-2:])
d_top = CorrMM(self.border_mode, d_top = CorrMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
d_height_width = ((theano.gradient.DisconnectedType()(),) * 2 d_height_width = ((theano.gradient.DisconnectedType()(),) * 2
if len(inp) == 4 else ()) if len(inp) == 4 else ())
return (d_bottom, d_top) + d_height_width return (d_bottom, d_top) + d_height_width
...@@ -738,12 +747,14 @@ class CorrMM_gradInputs(BaseCorrMM): ...@@ -738,12 +747,14 @@ class CorrMM_gradInputs(BaseCorrMM):
bottom, = grads bottom, = grads
d_weights = CorrMM_gradWeights(self.border_mode, d_weights = CorrMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, self.filter_dilation,
self.num_groups)(bottom,
top, top,
weights.shape[-2:]) weights.shape[-2:])
d_top = CorrMM(self.border_mode, d_top = CorrMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
d_height_width = ((theano.gradient.DisconnectedType()(),) * d_height_width = ((theano.gradient.DisconnectedType()(),) *
2 if len(inp) == 4 else ()) 2 if len(inp) == 4 else ())
return (d_weights, d_top) + d_height_width return (d_weights, d_top) + d_height_width
......
...@@ -106,7 +106,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -106,7 +106,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const int dilH = 1, const int dilH = 1,
const int dilW = 1, const int dilW = 1,
const int padH = 0, const int padH = 0,
const int padW = 0) const int padW = 0,
const int numgroups = 1)
{ {
if (PyArray_NDIM(bottom) != 4) if (PyArray_NDIM(bottom) != 4)
{ {
...@@ -155,7 +156,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -155,7 +156,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const int nFilters = PyArray_DIMS(weight)[0]; const int nFilters = PyArray_DIMS(weight)[0];
const int kH = PyArray_DIMS(weight)[2]; const int kH = PyArray_DIMS(weight)[2];
const int kW = PyArray_DIMS(weight)[3]; const int kW = PyArray_DIMS(weight)[3];
if (nChannels != PyArray_DIMS(weight)[1]) { if (nChannels != (PyArray_DIMS(weight)[1] * numgroups)) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"CorrMM images and kernel must have the same stack size\n"); "CorrMM images and kernel must have the same stack size\n");
return NULL; return NULL;
...@@ -214,12 +215,16 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -214,12 +215,16 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
} }
// Define some useful variables // Define some useful variables
const int bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f; const int batch_bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f;
const int top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f; const int group_bottom_stride = (PyArray_STRIDES(bottom)[1] * nChannels / numgroups)/%(n_bytes)f;
const int K_ = col_dim[1]; const int batch_top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f;
const int group_top_stride = (PyArray_STRIDES(top)[1] * nFilters / numgroups)/%(n_bytes)f;
const int K_ = col_dim[1] / numgroups;
const int N_ = col_dim[2]; const int N_ = col_dim[2];
const int col_stride = (K_ * N_); const int col_stride = (K_ * N_ * numgroups);
const int M_ = nFilters; const int group_col_stride = (K_ * N_);
const int group_weight_stride = (PyArray_STRIDES(weight)[0] * nFilters / numgroups)/%(n_bytes)f;
const int M_ = nFilters / numgroups;
const %(c_float_type)s one = 1.0; const %(c_float_type)s one = 1.0;
const %(c_float_type)s zero = 0.0; const %(c_float_type)s zero = 0.0;
char NTrans = 'N'; char NTrans = 'N';
...@@ -253,17 +258,19 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -253,17 +258,19 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride, nChannels,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
for ( int g = 0; g < numgroups; ++g){
// Second, gemm // Second, gemm
%(gemm)s(&NTrans, &NTrans, %(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_, &N_, &M_, &K_,
&one, &one,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride, &N_, (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_, (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero, &zero,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_); (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_);
}
} }
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
...@@ -304,7 +311,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -304,7 +311,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
output = weight; output = weight;
npy_intp weight_dim[2]; npy_intp weight_dim[2];
weight_dim[0] = (npy_intp)max_threads; weight_dim[0] = (npy_intp)max_threads;
weight_dim[1] = (npy_intp)(M_ * K_); weight_dim[1] = (npy_intp)(M_ * K_ * numgroups);
PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2, PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2,
weight_dim, PyArray_TYPE(weight), 0); weight_dim, PyArray_TYPE(weight), 0);
...@@ -326,9 +333,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -326,9 +333,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, nChannels, bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
for(int g = 0; g < numgroups; ++g){
// Second, gemm // Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
...@@ -336,12 +344,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -336,12 +344,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
%(gemm)s(&Trans, &NTrans, %(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_, &K_, &M_, &N_,
&one, &one,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_, (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_, (%(float_type)s*)PyArray_DATA(top) + g * group_top_stride + n * batch_top_stride, &N_,
(n == 0) ? &zero : &one, (n == 0) ? &zero : &one,
(%(float_type)s*)PyArray_DATA(local_weight) + (%(float_type)s*)PyArray_DATA(local_weight) + g * group_weight_stride +
tid * weight_dim[1], &K_); tid * weight_dim[1], &K_);
} }
}
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
...@@ -401,19 +410,21 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -401,19 +410,21 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
%(blas_set_num_threads)s(1); %(blas_set_num_threads)s(1);
%(omp_flags)s %(omp_flags)s
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
// gemm into columns
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
for ( int g = 0;g < numgroups; ++g){
// gemm into columns
%(gemm)s(&NTrans, &Trans, %(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_, &N_, &K_, &M_,
&one, &one,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_, (%(float_type)s*)PyArray_DATA(top) + g * group_top_stride + n * batch_top_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_, (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero, &zero,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_); (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_);
}
// col2im back to the data // col2im back to the data
col2im((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, bottomHeight, bottomWidth, col2im((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, bottomHeight, bottomWidth,
kH, kW, dilH, dilW, padH, padW, kH, kW, dilH, dilW, padH, padW,
dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride); dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride);
} }
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
......
...@@ -88,7 +88,9 @@ def local_abstractconv_gemm(node): ...@@ -88,7 +88,9 @@ def local_abstractconv_gemm(node):
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
rval = CorrMM(border_mode=node.op.border_mode, rval = CorrMM(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, kern) filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, kern)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
return [rval] return [rval]
...@@ -133,7 +135,8 @@ def local_abstractconv_gradweight_gemm(node): ...@@ -133,7 +135,8 @@ def local_abstractconv_gradweight_gemm(node):
rval = CorrMM_gradWeights(border_mode=node.op.border_mode, rval = CorrMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, topgrad, shape) filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, topgrad, shape)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
# need to flip the kernel if necessary # need to flip the kernel if necessary
...@@ -190,7 +193,8 @@ def local_abstractconv_gradinputs_gemm(node): ...@@ -190,7 +193,8 @@ def local_abstractconv_gradinputs_gemm(node):
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
rval = CorrMM_gradInputs(border_mode=node.op.border_mode, rval = CorrMM_gradInputs(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(kern, topgrad, filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(kern, topgrad,
shape) shape)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论