提交 10bac9b1 作者: affanv14

modify Corr3dMM to support num_groups

上级 178ae035
......@@ -127,7 +127,8 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
const int dilD = 1,
const int padH = 0,
const int padW = 0,
const int padD = 0)
const int padD = 0,
const int numgroups=1)
{
if (PyArray_NDIM(bottom) != 5)
{
......@@ -178,7 +179,7 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
const int kH = PyArray_DIMS(weight)[2];
const int kW = PyArray_DIMS(weight)[3];
const int kD = PyArray_DIMS(weight)[4];
if (nChannels != PyArray_DIMS(weight)[1]) {
if (nChannels != PyArray_DIMS(weight)[1] * numgroups) {
PyErr_SetString(PyExc_ValueError,
"Corr3dMM images and kernel must have the same stack size\n");
return NULL;
......@@ -210,7 +211,7 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
" weight shape: %%d %%d %%d %%d %%d\n"
" top shape: %%ld %%ld %%ld %%ld %%ld (expected %%d %%d %%d %%d %%d)\n",
batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
nFilters, nChannels, kH, kW, kD,
nFilters, nChannels / numgroups, kH, kW, kD,
PyArray_DIMS(top)[0], PyArray_DIMS(top)[1],
PyArray_DIMS(top)[2], PyArray_DIMS(top)[3], PyArray_DIMS(top)[4],
batchSize, nFilters, topHeight, topWidth, topDepth);
......@@ -241,12 +242,16 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
}
// Define some useful variables
const int bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f;
const int top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f;
const int K_ = col_dim[1];
const int batch_bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f;
const int group_bottom_stride = (PyArray_STRIDES(bottom)[1] * nChannels / numgroups)/%(n_bytes)f;
const int batch_top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f;
const int group_top_stride = (PyArray_STRIDES(top)[1] * nFilters / numgroups)/%(n_bytes)f;
const int K_ = col_dim[1] / numgroups;
const int N_ = col_dim[2];
const int col_stride = (K_ * N_);
const int M_ = nFilters;
const int col_stride = (K_ * N_ * numgroups);
const int group_col_stride = (K_ * N_);
const int group_weight_stride = (PyArray_STRIDES(weight)[0] * nFilters / numgroups)/%(n_bytes)f;
const int M_ = nFilters / numgroups;
const %(c_float_type)s one = 1.0;
const %(c_float_type)s zero = 0.0;
char NTrans = 'N';
......@@ -280,18 +285,21 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s;
// First, im3d2col
im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels,
bottomHeight, bottomWidth, bottomDepth,
im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
nChannels, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm
%(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_,
&one,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_,
&zero,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_);
for ( int g = 0; g < numgroups; ++g){
// Second, gemm
%(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_,
&one,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero,
(%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_);
}
}
// Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved);
......@@ -300,7 +308,7 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
output = weight;
npy_intp weight_dim[2];
weight_dim[0] = (npy_intp)max_threads;
weight_dim[1] = (npy_intp)(M_ * K_);
weight_dim[1] = (npy_intp)(M_ * K_ * numgroups);
PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2,
weight_dim, PyArray_TYPE(weight), 0);
......@@ -322,22 +330,25 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s;
// First, im2col
im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels,
bottomHeight, bottomWidth, bottomDepth,
im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
nChannels, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.)
%(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_,
&one,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_,
(n == 0) ? &zero : &one,
(%(float_type)s*)PyArray_DATA(local_weight) +
tid * weight_dim[1], &K_);
for ( int g = 0; g < numgroups; ++g){
// Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.)
%(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_,
&one,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_,
(n == 0) ? &zero : &one,
(%(float_type)s*)PyArray_DATA(local_weight) + g * group_weight_stride +
tid * weight_dim[1], &K_);
}
}
// Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved);
......@@ -370,20 +381,23 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
%(blas_set_num_threads)s(1);
%(omp_flags)s
for (int n = 0; n < batchSize; ++n) {
// gemm into columns
int tid = %(omp_get_thread_num)s;
%(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_,
&one,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_,
&zero,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_);
for ( int g = 0; g < numgroups; ++g){
// gemm into columns
%(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_,
&one,
(%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_);
}
// col2im back to the data
col2im3d((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels,
bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
(%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride);
(%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride);
}
// Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved);
......
......@@ -51,10 +51,11 @@ class BaseCorr3dMM(gof.OpenMPOp):
('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2
dH=int64, dW=int64, dD=int64,
dilH=int64, dilW=int64, dilD=int64,
padH=int64, padW=int64, padD=int64)
padH=int64, padW=int64, padD=int64,
num_groups=int64)
def __init__(self, border_mode="valid", subsample=(1, 1, 1),
filter_dilation=(1, 1, 1), openmp=None):
filter_dilation=(1, 1, 1), openmp=None, num_groups=1):
super(BaseCorr3dMM, self).__init__(openmp=openmp)
if isinstance(border_mode, integer_types):
if border_mode < 0:
......@@ -82,6 +83,9 @@ class BaseCorr3dMM(gof.OpenMPOp):
raise ValueError("filter_dilation must have three elements")
self.subsample = tuple(subsample)
self.filter_dilation = tuple(filter_dilation)
if num_groups < 1:
raise ValueError("Number of groups should be greater than 0")
self.num_groups = num_groups
if not theano.config.blas.ldflags:
# Theano will use a NumPy C implementation of [sd]gemm_ instead.
......@@ -127,11 +131,12 @@ class BaseCorr3dMM(gof.OpenMPOp):
padD = property(lambda self: self.pad[2])
def __str__(self):
return '%s{%s, %s, %s}' % (
return '%s{%s, %s, %s, %s}' % (
self.__class__.__name__,
self.border_mode,
str(self.subsample),
str(self.filter_dilation))
str(self.filter_dilation),
str(self.num_groups))
@staticmethod
def as_common_dtype(in1, in2):
......@@ -293,6 +298,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
int padH = %(params)s->padH;
int padW = %(params)s->padW;
int padD = %(params)s->padD;
int numgroups = %(params)s->num_groups;
PyArrayObject * bottom = %(bottom)s;
PyArrayObject * weights = %(weights)s;
......@@ -428,7 +434,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
// output is weights: (num_filters, num_channels, height, width, depth)
// height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
out_dim[0] = (npy_intp)PyArray_DIMS(top)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1] / numgroups;
out_dim[2] = (npy_intp)kH; // already inferred further above
out_dim[3] = (npy_intp)kW; // how convenient
out_dim[4] = (npy_intp)kD;
......@@ -454,7 +460,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
// output is bottom: (batchsize, num_channels, height, width, depth)
// height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1] * numgroups;
out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
out_dim[4] = (npy_intp)((%(depth)s != -1) ? %(depth)s : (PyArray_DIMS(top)[4] - 1) * dD + (PyArray_DIMS(weights)[4]-1)*dilD + 1 - 2*padD);
......@@ -516,7 +522,8 @@ class BaseCorr3dMM(gof.OpenMPOp):
// Call corr3dMM code
out2 = corr3dMM(%(bottom)s, %(weights)s, %(top)s, direction,
dH, dW, dD, dilH, dilW, dilD, padH, padW, padD);
dH, dW, dD, dilH, dilW, dilD, padH, padW, padD,
numgroups);
if (out2==NULL){
%(fail)s
}
......@@ -592,12 +599,14 @@ class Corr3dMM(BaseCorr3dMM):
top, = grads
d_bottom = Corr3dMM_gradInputs(self.border_mode,
self.subsample,
self.filter_dilation)(weights, top,
bottom.shape[-3:])
self.filter_dilation,
num_groups=self.num_groups)(weights, top,
bottom.shape[-3:])
d_weights = Corr3dMM_gradWeights(self.border_mode,
self.subsample,
self.filter_dilation)(bottom, top,
weights.shape[-3:])
self.filter_dilation,
num_groups=self.num_groups)(bottom, top,
weights.shape[-3:])
return d_bottom, d_weights
......@@ -691,11 +700,13 @@ class Corr3dMM_gradWeights(BaseCorr3dMM):
weights, = grads
d_bottom = Corr3dMM_gradInputs(self.border_mode,
self.subsample,
self.filter_dilation)(weights, top,
bottom.shape[-3:])
self.filter_dilation,
num_groups=self.num_groups)(weights, top,
bottom.shape[-3:])
d_top = Corr3dMM(self.border_mode,
self.subsample,
self.filter_dilation)(bottom, weights)
self.filter_dilation,
num_groups=self.num_groups)(bottom, weights)
d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3
if len(inp) == 5 else ())
return (d_bottom, d_top) + d_height_width_depth
......@@ -738,8 +749,12 @@ class Corr3dMM_gradInputs(BaseCorr3dMM):
as_tensor_variable(shape[1]).astype('int64'),
as_tensor_variable(shape[2]).astype('int64')]
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False, False]
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0], False,
False, False, False]
else:
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False, False]
dtype = kern.type.dtype
return Apply(self, [kern, topgrad] + height_width_depth,
[TensorType(dtype, broadcastable)()])
......@@ -807,12 +822,14 @@ class Corr3dMM_gradInputs(BaseCorr3dMM):
bottom, = grads
d_weights = Corr3dMM_gradWeights(self.border_mode,
self.subsample,
self.filter_dilation)(bottom,
top,
weights.shape[-3:])
self.filter_dilation,
num_groups=self.num_groups)(bottom,
top,
weights.shape[-3:])
d_top = Corr3dMM(self.border_mode,
self.subsample,
self.filter_dilation)(bottom, weights)
self.filter_dilation,
num_groups=self.num_groups)(bottom, weights)
d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3
if len(inp) == 5 else ())
return (d_weights, d_top) + d_height_width_depth
......
......@@ -114,7 +114,8 @@ def local_abstractconv3d_gemm(node):
kern = kern[:, :, ::-1, ::-1, ::-1]
rval = Corr3dMM(border_mode=node.op.border_mode,
subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, kern)
filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, kern)
copy_stack_trace(node.outputs[0], rval)
return [rval]
......@@ -163,7 +164,8 @@ def local_abstractconv3d_gradweight_gemm(node):
rval = Corr3dMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, topgrad, shape)
filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, topgrad, shape)
copy_stack_trace(node.outputs[0], rval)
# need to flip the kernel if necessary
......@@ -219,8 +221,9 @@ def local_abstractconv3d_gradinputs_gemm(node):
kern = kern[:, :, ::-1, ::-1, ::-1]
rval = Corr3dMM_gradInputs(border_mode=node.op.border_mode,
subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(kern, topgrad,
shape)
filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(kern, topgrad,
shape)
copy_stack_trace(node.outputs[0], rval)
return [rval]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论