提交 38e03e0f authored 作者: Vikram's avatar Vikram

Started asymmetric padding

上级 9592125c
...@@ -31,23 +31,23 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -31,23 +31,23 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
void im2col(const %(float_type)s* data_im, const int channels, void im2col(const %(float_type)s* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w, const int height, const int width, const int kernel_h, const int kernel_w,
const int dilation_h, const int dilation_w, const int dilation_h, const int dilation_w,
const int pad_h, const int pad_w, const int pad_hl, const int pad_hr, const int pad_wl, const int pad_wr,
const int stride_h, const int stride_w, const int stride_h, const int stride_w,
%(float_type)s* data_col) { %(float_type)s* data_col) {
// Implicit dilated kernel size // Implicit dilated kernel size
int dil_kernel_h = (kernel_h - 1) * dilation_h + 1; int dil_kernel_h = (kernel_h - 1) * dilation_h + 1;
int dil_kernel_w = (kernel_w - 1) * dilation_w + 1; int dil_kernel_w = (kernel_w - 1) * dilation_w + 1;
int height_col = (height + 2 * pad_h - dil_kernel_h) / stride_h + 1; int height_col = (height + pad_hl + pad_hr - dil_kernel_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - dil_kernel_w) / stride_w + 1; int width_col = (width + pad_wl + pad_wr - dil_kernel_w) / stride_w + 1;
int channels_col = channels * kernel_h * kernel_w; int channels_col = channels * kernel_h * kernel_w;
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c %% kernel_w; int w_offset = c %% kernel_w;
int h_offset = (c / kernel_w) %% kernel_h; int h_offset = (c / kernel_w) %% kernel_h;
int c_im = c / kernel_h / kernel_w; int c_im = c / kernel_h / kernel_w;
for (int h = 0; h < height_col; ++h) { for (int h = 0; h < height_col; ++h) {
int h_pad = h * stride_h - pad_h + h_offset * dilation_h; int h_pad = h * stride_h - pad_hl + h_offset * dilation_h;
for (int w = 0; w < width_col; ++w) { for (int w = 0; w < width_col; ++w) {
int w_pad = w * stride_w - pad_w + w_offset * dilation_w; int w_pad = w * stride_w - pad_wl + w_offset * dilation_w;
if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
data_col[(npy_intp)(c * height_col + h) * width_col + w] = data_col[(npy_intp)(c * height_col + h) * width_col + w] =
data_im[(npy_intp)(c_im * height + h_pad) * width + w_pad]; data_im[(npy_intp)(c_im * height + h_pad) * width + w_pad];
...@@ -64,13 +64,14 @@ void im2col(const %(float_type)s* data_im, const int channels, ...@@ -64,13 +64,14 @@ void im2col(const %(float_type)s* data_im, const int channels,
void col2im(const %(float_type)s* data_col, const int channels, void col2im(const %(float_type)s* data_col, const int channels,
const int height, const int width, const int patch_h, const int patch_w, const int height, const int width, const int patch_h, const int patch_w,
const int dilation_h, const int dilation_w, const int dilation_h, const int dilation_w,
const int pad_h, const int pad_w, const int stride_h, const int pad_hl, const int pad_hr, const int pad_wl, const int pad_wr,
const int stride_w, %(float_type)s* data_im) { const int stride_h, const int stride_w,
%(float_type)s* data_im) {
// Implicit dilated patch // Implicit dilated patch
int dil_patch_h = (patch_h - 1) * dilation_h + 1; int dil_patch_h = (patch_h - 1) * dilation_h + 1;
int dil_patch_w = (patch_w - 1) * dilation_w + 1; int dil_patch_w = (patch_w - 1) * dilation_w + 1;
int height_col = (height + 2 * pad_h - dil_patch_h) / stride_h + 1; int height_col = (height + pad_hl + pad_hr - dil_patch_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - dil_patch_w) / stride_w + 1; int width_col = (width + pad_wl + pad_wr - dil_patch_w) / stride_w + 1;
int num_kernels = channels * height * width; int num_kernels = channels * height * width;
int channels_col = channels * patch_h * patch_w; int channels_col = channels * patch_h * patch_w;
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
...@@ -78,9 +79,9 @@ void col2im(const %(float_type)s* data_col, const int channels, ...@@ -78,9 +79,9 @@ void col2im(const %(float_type)s* data_col, const int channels,
int h_offset = (c / patch_w) %% patch_h; int h_offset = (c / patch_w) %% patch_h;
int c_im = c / patch_h / patch_w; int c_im = c / patch_h / patch_w;
for (int h = 0; h < height_col; ++h) { for (int h = 0; h < height_col; ++h) {
int h_pad = h * stride_h - pad_h + h_offset * dilation_h; int h_pad = h * stride_h - pad_hl + h_offset * dilation_h;
for (int w = 0; w < width_col; ++w) { for (int w = 0; w < width_col; ++w) {
int w_pad = w * stride_w - pad_w + w_offset * dilation_w; int w_pad = w * stride_w - pad_wl + w_offset * dilation_w;
if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
data_im[(npy_intp)(c_im * height + h_pad) * width + w_pad] += data_im[(npy_intp)(c_im * height + h_pad) * width + w_pad] +=
data_col[(npy_intp)(c * height_col + h) * width_col + w]; data_col[(npy_intp)(c * height_col + h) * width_col + w];
...@@ -105,8 +106,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -105,8 +106,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const int dW = 1, const int dW = 1,
const int dilH = 1, const int dilH = 1,
const int dilW = 1, const int dilW = 1,
const int padH = 0, const int padH_l = 0,
const int padW = 0, const int padH_r = 0,
const int padW_l = 0,
const int padW_r = 0,
const int numgroups = 1, const int numgroups = 1,
const int unshared = 0) const int unshared = 0)
{ {
...@@ -172,8 +175,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -172,8 +175,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const int dil_kH = (kH - 1) * dilH + 1; const int dil_kH = (kH - 1) * dilH + 1;
const int dil_kW = (kW - 1) * dilW + 1; const int dil_kW = (kW - 1) * dilW + 1;
// top: (batchSize, nFilters, topHeight, topWidth) // top: (batchSize, nFilters, topHeight, topWidth)
const int topHeightNoDH = (bottomHeight + 2*padH - dil_kH); const int topHeightNoDH = (bottomHeight + padH_l + padH_r - dil_kH);
const int topWidthNoDW = (bottomWidth + 2*padW - dil_kW); const int topWidthNoDW = (bottomWidth + padW_l + padW_r - dil_kW);
// the above values might be negative so we need to use Python-like // the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output. // flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only // note: this macro implements Python's // for negative x only
...@@ -303,7 +306,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -303,7 +306,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride, nChannels, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride, nChannels,
bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH_l, padH_r, padW_l, padW_r, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm // Second, gemm
if (unshared) { if (unshared) {
...@@ -396,7 +399,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -396,7 +399,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
nChannels, bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, nChannels, bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH_l, padH_r, padW_l, padW_r, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
// Second, gemm // Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
...@@ -519,7 +522,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -519,7 +522,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
} }
// col2im back to the data // col2im back to the data
col2im((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, bottomHeight, bottomWidth, col2im((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, bottomHeight, bottomWidth,
kH, kW, dilH, dilW, padH, padW, kH, kW, dilH, dilW, padH_l, padH_r, padW_l, padW_r,
dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride); dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride);
} }
// Restore to previous blas threads // Restore to previous blas threads
......
...@@ -66,20 +66,29 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -66,20 +66,29 @@ class BaseCorrMM(gof.OpenMPOp):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be a ' 'invalid border_mode {}, which must be a '
'non-negative integer'.format(border_mode)) 'non-negative integer'.format(border_mode))
border_mode = (border_mode, border_mode) border_mode = ((border_mode, border_mode),) * 2
if isinstance(border_mode, tuple): elif isinstance(border_mode, tuple):
if len(border_mode) != 2 or border_mode[0] < 0 or border_mode[1] < 0: if len(border_mode) != 2:
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be a ' 'invalid border_mode {} which must be a '
'pair of non-negative integers'.format(border_mode)) 'tuple of length 2'.format(border_mode))
pad_h, pad_w = map(int, border_mode) border = ()
border_mode = (pad_h, pad_w) for mode in border_mode:
if not ((isinstance(border_mode, tuple) and min(border_mode) >= 0) or if isinstance(mode, integer_types) and mode > 0:
border_mode in ('valid', 'full', 'half')): border += ((mode, mode),)
elif isinstance(mode, tuple) and len(mode) == 2 and \
min(mode) >= 0:
border = ((mode[0], mode[1]),)
else:
raise ValueError(
'invalid border mode {}. The tuple can only contain '
'integers or tuples of length 2'.format(border_mode))
border_mode = border
elif border_mode not in ('valid', 'full', 'half'):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be either ' 'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a pair of' '"valid", "full", "half", an integer or a tuple '
' integers'.format(border_mode)) 'of length 2'.format(border_mode))
self.border_mode = border_mode self.border_mode = border_mode
if len(subsample) != 2: if len(subsample) != 2:
raise ValueError("subsample must have two elements") raise ValueError("subsample must have two elements")
...@@ -110,14 +119,14 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -110,14 +119,14 @@ class BaseCorrMM(gof.OpenMPOp):
@property @property
def pad(self): def pad(self):
if self.border_mode == "half": if self.border_mode == "half":
return (-1, -1) return ((-1, -1),) * 2
elif self.border_mode == "full": elif self.border_mode == "full":
return (-2, -2) return ((-2, -2),) * 2
elif isinstance(self.border_mode, tuple): elif isinstance(self.border_mode, tuple):
return self.border_mode return self.border_mode
else: else:
assert self.border_mode == "valid" assert self.border_mode == "valid"
return (0, 0) return ((0, 0),) * 2
# Direction should be converted to real enum value, # Direction should be converted to real enum value,
# as it is compared to integer later in c_code_helper(). # as it is compared to integer later in c_code_helper().
...@@ -129,8 +138,10 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -129,8 +138,10 @@ class BaseCorrMM(gof.OpenMPOp):
dilH = property(lambda self: self.filter_dilation[0]) dilH = property(lambda self: self.filter_dilation[0])
dilW = property(lambda self: self.filter_dilation[1]) dilW = property(lambda self: self.filter_dilation[1])
padH = property(lambda self: self.pad[0]) padH_l = property(lambda self: self.pad[0][0])
padW = property(lambda self: self.pad[1]) padH_r = property(lambda self: self.pad[0][1])
padW_l = property(lambda self: self.pad[1][0])
padW_r = property(lambda self: self.pad[1][1])
def __str__(self): def __str__(self):
return '%s{%s, %s, %s, %s %s}' % ( return '%s{%s, %s, %s, %s %s}' % (
...@@ -271,13 +282,13 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -271,13 +282,13 @@ class BaseCorrMM(gof.OpenMPOp):
if height: if height:
height = '(*(npy_int64 *)(PyArray_DATA(%s)))' % height height = '(*(npy_int64 *)(PyArray_DATA(%s)))' % height
else: else:
if ((self.direction != 0) and (self.dH != 1)) or ((self.direction == 1) and (self.padH == -1)): if ((self.direction != 0) and (self.dH != 1)) or ((self.direction == 1) and (self.padH_l == -1)):
raise ValueError("height must be given for backprop with vertical sampling or border_mode='half'") raise ValueError("height must be given for backprop with vertical sampling or border_mode='half'")
height = '-1' height = '-1'
if width: if width:
width = '(*(npy_int64 *)(PyArray_DATA(%s)))' % width width = '(*(npy_int64 *)(PyArray_DATA(%s)))' % width
else: else:
if ((self.direction != 0) and (self.dW != 1)) or ((self.direction == 1) and (self.padW == -1)): if ((self.direction != 0) and (self.dW != 1)) or ((self.direction == 1) and (self.padW_l == -1)):
raise ValueError("width must be given for backprop with horizontal sampling or border_mode='half'") raise ValueError("width must be given for backprop with horizontal sampling or border_mode='half'")
width = '-1' width = '-1'
...@@ -290,8 +301,10 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -290,8 +301,10 @@ class BaseCorrMM(gof.OpenMPOp):
int dW = %(params)s->dW; int dW = %(params)s->dW;
int dilH = %(params)s->dilH; int dilH = %(params)s->dilH;
int dilW = %(params)s->dilW; int dilW = %(params)s->dilW;
int padH = %(params)s->padH; int padH_l = %(params)s->padH_l;
int padW = %(params)s->padW; int padH_r = %(params)s->padH_r;
int padW_l = %(params)s->padW_l;
int padW_r = %(params)s->padW_r;
int numgroups = %(params)s->num_groups; int numgroups = %(params)s->num_groups;
int unshared = %(params)s->unshared; int unshared = %(params)s->unshared;
...@@ -340,7 +353,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -340,7 +353,7 @@ class BaseCorrMM(gof.OpenMPOp):
} }
else { else {
// explicit padding, we can infer the kernel height // explicit padding, we can infer the kernel height
kH = (PyArray_DIMS(bottom)[2] + 2*padH - (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1; kH = (PyArray_DIMS(bottom)[2] + padH_l + padH_r - (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
} }
if (%(width)s != -1) { if (%(width)s != -1) {
// kernel width is specified (perhaps horizontal subsampling or half padding) // kernel width is specified (perhaps horizontal subsampling or half padding)
...@@ -350,7 +363,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -350,7 +363,7 @@ class BaseCorrMM(gof.OpenMPOp):
kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1; kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
} }
else { else {
kW = (PyArray_DIMS(bottom)[3] + 2*padW - (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1; kW = (PyArray_DIMS(bottom)[3] + padW_l + padW_r - (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
} }
} }
...@@ -386,11 +399,11 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -386,11 +399,11 @@ class BaseCorrMM(gof.OpenMPOp):
switch(direction) { switch(direction) {
case 0: // forward pass case 0: // forward pass
// output is top: (batchsize, num_filters, height, width) // output is top: (batchsize, num_filters, height, width)
// height and width: top = (bottom + 2*pad - ((weight-1)*dil + 1)) / sample + 1 // height and width: top = (bottom + pad_l + pad_r - ((weight-1)*dil + 1)) / sample + 1
out_dim[0] = (npy_intp)PyArray_DIMS(bottom)[0]; out_dim[0] = (npy_intp)PyArray_DIMS(bottom)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0]; out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] + 2*padH - ((PyArray_DIMS(weights)[wdim-2]-1)*dilH + 1)) / dH + 1); out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] + padH_l + padH_r - ((PyArray_DIMS(weights)[wdim-2]-1)*dilH + 1)) / dH + 1);
out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] + 2*padW - ((PyArray_DIMS(weights)[wdim-1]-1)*dilW + 1)) / dW + 1); out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] + padW_l + padW_r - ((PyArray_DIMS(weights)[wdim-1]-1)*dilW + 1)) / dW + 1);
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0) if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
{ {
if (unshared) { if (unshared) {
...@@ -564,7 +577,8 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -564,7 +577,8 @@ class BaseCorrMM(gof.OpenMPOp):
} }
// Call corrMM code // Call corrMM code
out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW, numgroups, unshared); out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW,
padH_l, padH_r, padW_l, padW_r, numgroups, unshared);
if (out2==NULL){ if (out2==NULL){
%(fail)s %(fail)s
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论