提交 9daf7acc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fixes and cleanup for maxpool with padding

上级 59faec8e
...@@ -1603,7 +1603,6 @@ def local_gpu_downsample_factor_max(node): ...@@ -1603,7 +1603,6 @@ def local_gpu_downsample_factor_max(node):
assert node.op.__props__ == ('ds', 'ignore_border', 'st', 'padding') assert node.op.__props__ == ('ds', 'ignore_border', 'st', 'padding')
if node.op.padding != (0, 0): if node.op.padding != (0, 0):
return return
assert node.op.__props__ == ('ds', 'ignore_border', 'st')
x, = node.inputs x, = node.inputs
if (x.owner and isinstance(x.owner.op, HostFromGpu)): if (x.owner and isinstance(x.owner.op, HostFromGpu)):
gpu_ds = GpuDownsampleFactorMax(node.op.ds, node.op.ignore_border) gpu_ds = GpuDownsampleFactorMax(node.op.ds, node.op.ignore_border)
......
...@@ -48,7 +48,7 @@ def max_pool_2d(input, ds, ignore_border=False, st=None, padding=(0, 0)): ...@@ -48,7 +48,7 @@ def max_pool_2d(input, ds, ignore_border=False, st=None, padding=(0, 0)):
if input.ndim < 2: if input.ndim < 2:
raise NotImplementedError('max_pool_2d requires a dimension >= 2') raise NotImplementedError('max_pool_2d requires a dimension >= 2')
if input.ndim == 4: if input.ndim == 4:
op = DownsampleFactorMax(ds, ignore_border, st=st) op = DownsampleFactorMax(ds, ignore_border, st=st, padding=padding)
output = op(input) output = op(input)
return output return output
...@@ -193,10 +193,9 @@ class DownsampleFactorMax(Op): ...@@ -193,10 +193,9 @@ class DownsampleFactorMax(Op):
self.st = tuple(st) self.st = tuple(st)
self.ignore_border = ignore_border self.ignore_border = ignore_border
self.padding = tuple(padding) self.padding = tuple(padding)
self.padding = padding if self.padding != (0, 0) and not ignore_border:
if padding != (0, 0) and not ignore_border:
raise NotImplementedError( raise NotImplementedError(
'padding works only with ignore_boarder=True') 'padding works only with ignore_border=True')
if self.padding[0] >= self.ds[0] or self.padding[1] >= self.ds[1]: if self.padding[0] >= self.ds[0] or self.padding[1] >= self.ds[1]:
raise NotImplementedError( raise NotImplementedError(
'padding_h and padding_w must be smaller than strides') 'padding_h and padding_w must be smaller than strides')
...@@ -213,8 +212,6 @@ class DownsampleFactorMax(Op): ...@@ -213,8 +212,6 @@ class DownsampleFactorMax(Op):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out):
"""
"""
x, = inp x, = inp
z, = out z, = out
if len(x.shape) != 4: if len(x.shape) != 4:
...@@ -238,7 +235,6 @@ class DownsampleFactorMax(Op): ...@@ -238,7 +235,6 @@ class DownsampleFactorMax(Op):
pad_w = self.padding[1] pad_w = self.padding[1]
img_rows = x.shape[-2] + 2 * pad_h img_rows = x.shape[-2] + 2 * pad_h
img_cols = x.shape[-1] + 2 * pad_w img_cols = x.shape[-1] + 2 * pad_w
# pad the image # pad the image
fill = x.min()-1. fill = x.min()-1.
...@@ -391,7 +387,7 @@ class DownsampleFactorMaxGrad(Op): ...@@ -391,7 +387,7 @@ class DownsampleFactorMaxGrad(Op):
pad_w = self.padding[1] pad_w = self.padding[1]
img_rows = x.shape[-2] + 2 * pad_h img_rows = x.shape[-2] + 2 * pad_h
img_cols = x.shape[-1] + 2 * pad_w img_cols = x.shape[-1] + 2 * pad_w
# pad the image # pad the image
fill = x.min()-1 fill = x.min()-1
y = numpy.zeros( y = numpy.zeros(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论