提交 5b220efc authored 作者: Li's avatar Li

added support for padding

上级 f120de51
......@@ -19,7 +19,7 @@ def max_pool2D(*args, **kwargs):
return max_pool_2d(*args, **kwargs)
def max_pool_2d(input, ds, ignore_border=False, st=None):
def max_pool_2d(input, ds, ignore_border=False, st=None, padding=(0,0)):
"""
Takes as input a N-D tensor, where N >= 2. It downscales the input image by
the specified factor, by keeping only the maximum value of non-overlapping
......@@ -62,7 +62,7 @@ def max_pool_2d(input, ds, ignore_border=False, st=None):
input_4D = tensor.reshape(input, new_shape, ndim=4)
# downsample mini-batch of images
op = DownsampleFactorMax(ds, ignore_border, st=st)
op = DownsampleFactorMax(ds, ignore_border, st=st, padding=padding)
output = op(input_4D)
# restore to original shape
......@@ -77,10 +77,10 @@ class DownsampleFactorMax(Op):
regions.
"""
__props__ = ('ds', 'ignore_border', 'st')
__props__ = ('ds', 'ignore_border', 'st', 'padding')
@staticmethod
def out_shape(imgshape, ds, ignore_border=False, st=None):
def out_shape(imgshape, ds, ignore_border=False, st=None, padding=(0,0)):
"""Return the shape of the output from this op, for input of given
shape and flags.
......@@ -113,7 +113,9 @@ class DownsampleFactorMax(Op):
if st is None:
st = ds
r, c = imgshape[-2:]
r += padding[0] * 2
c += padding[1] * 2
if ignore_border:
out_r = (r - ds[0]) // st[0] + 1
out_c = (c - ds[1]) // st[1] + 1
......@@ -149,7 +151,7 @@ class DownsampleFactorMax(Op):
rval = list(imgshape[:-2]) + [nr, nc]
return rval
def __init__(self, ds, ignore_border=False, st=None):
def __init__(self, ds, ignore_border=False, st=None, padding=(0,0)):
"""
:param ds: downsample factor over rows and column.
ds indicates the pool region size.
......@@ -176,10 +178,15 @@ class DownsampleFactorMax(Op):
st = ds
self.st = tuple(st)
self.ignore_border = ignore_border
self.padding = tuple(padding)
self.padding = padding
if padding != (0,0) and not ignore_border:
raise NotImplementedError('padding works only with ignore_boarder=True')
if self.padding[0] >= self.st[0] or self.padding[1] >= self.st[1]:
raise NotImplementedError('padding_h and padding_w must be smaller than strides')
def __str__(self):
return '%s{%s,%s,%s}' % (self.__class__.__name__,
self.ds, self.st, self.ignore_border)
return '%s{%s,%s,%s,%s}' % (self.__class__.__name__,
self.ds, self.st, self.ignore_border,self.padding)
def make_node(self, x):
if x.type.ndim != 4:
......@@ -195,7 +202,7 @@ class DownsampleFactorMax(Op):
if len(x.shape) != 4:
raise NotImplementedError(
'DownsampleFactorMax requires 4D input for now')
z_shape = self.out_shape(x.shape, self.ds, self.ignore_border, self.st)
z_shape = self.out_shape(x.shape, self.ds, self.ignore_border, self.st,self.padding)
if (z[0] is None) or (z[0].shape != z_shape):
z[0] = numpy.empty(self.out_shape(x.shape, self.ds,
self.ignore_border, self.st),
......@@ -208,9 +215,31 @@ class DownsampleFactorMax(Op):
pc = zz.shape[-1]
ds0, ds1 = self.ds
st0, st1 = self.st
img_rows = x.shape[-2]
img_cols = x.shape[-1]
img_rows = x.shape[-2] + 2 * self.padding[0]
img_cols = x.shape[-1] + 2 * self.padding[1]
pad_h = self.padding[0]
pad_w = self.padding[1]
def get_valid_corners(x):
# x (m,c,h,w)
img_h,img_w = x.shape[-2:]
row_st_valid = pad_h
row_end_valid = img_h + pad_h
col_st_valid = pad_w
col_end_valid = img_w + pad_w
return row_st_valid, row_end_valid, col_st_valid, col_end_valid
row_st_valid, row_end_valid, col_st_valid, col_end_valid = get_valid_corners(x)
def shrink(row_st, row_end, col_st, col_end):
# this will shrink the pooling region such that padded areas are ignored
# when performing max
if row_st <= row_st_valid:
row_st = row_st_valid
if row_end >= row_end_valid:
row_end = row_end_valid
if col_st <= col_st_valid:
col_st = col_st_valid
if col_end >= col_end_valid:
col_end = col_end_valid
return row_st, row_end, col_st, col_end
for n in xrange(x.shape[0]):
for k in xrange(x.shape[1]):
for r in xrange(pr):
......@@ -219,6 +248,8 @@ class DownsampleFactorMax(Op):
for c in xrange(pc):
col_st = c * st1
col_end = __builtin__.min(col_st + ds1, img_cols)
row_st, row_end, col_st, col_end = shrink(
row_st, row_end, col_st, col_end)
zz[n, k, r, c] = x[
n, k, row_st:row_end, col_st:col_end].max()
......@@ -320,16 +351,17 @@ class DownsampleFactorMax(Op):
class DownsampleFactorMaxGrad(Op):
__props__ = ('ds', 'ignore_border', 'st')
def __init__(self, ds, ignore_border, st=None):
def __init__(self, ds, ignore_border, st=None, padding=(0,0)):
self.ds = tuple(ds)
self.ignore_border = ignore_border
if st is None:
st = ds
self.st = tuple(st)
self.padding = tuple(padding)
def __str__(self):
return '%s{%s,%s,%s}' % (self.__class__.__name__,
self.ds, self.st, self.ignore_border)
return '%s{%s,%s,%s,%s}' % (self.__class__.__name__,
self.ds, self.st, self.ignore_border,self.padding)
def make_node(self, x, maxout, gz):
# make_node should only be called by the grad function of
......@@ -351,9 +383,31 @@ class DownsampleFactorMaxGrad(Op):
pc = maxout.shape[-1]
ds0, ds1 = self.ds
st0, st1 = self.st
img_rows = x.shape[-2]
img_cols = x.shape[-1]
img_rows = x.shape[-2] + 2 * self.padding[0]
img_cols = x.shape[-1] + 2 * self.padding[1]
pad_h = self.padding[0]
pad_w = self.padding[1]
def get_valid_corners(x):
# x (m,c,h,w)
img_h,img_w = x.shape[-2:]
row_st_valid = pad_h
row_end_valid = img_h + pad_h
col_st_valid = pad_w
col_end_valid = img_w + pad_w
return row_st_valid, row_end_valid, col_st_valid, col_end_valid
row_st_valid, row_end_valid, col_st_valid, col_end_valid = get_valid_corners(x)
def shrink(row_st, row_end, col_st, col_end):
# this will shrink the pooling region such that padded areas are ignored
# when performing max
if row_st <= row_st_valid:
row_st = row_st_valid
if row_end >= row_end_valid:
row_end = row_end_valid
if col_st <= col_st_valid:
col_st = col_st_valid
if col_end >= col_end_valid:
col_end = col_end_valid
return row_st, row_end, col_st, col_end
for n in xrange(x.shape[0]):
for k in xrange(x.shape[1]):
for r in xrange(pr):
......@@ -362,6 +416,8 @@ class DownsampleFactorMaxGrad(Op):
for c in xrange(pc):
col_st = c * st1
col_end = __builtin__.min(col_st + ds1, img_cols)
row_st, row_end, col_st, col_end = shrink(
row_st, row_end, col_st, col_end)
for row_ind in xrange(row_st, row_end):
for col_ind in xrange(col_st, col_end):
if (maxout[n, k, r, c] == x[n, k, row_ind, col_ind]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论