提交 417e6fb7 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Fix border_mode

上级 70653920
...@@ -510,8 +510,9 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -510,8 +510,9 @@ class AbstractConv2d(BaseAbstractConv2d):
o, = out_ o, = out_
mode = self.border_mode mode = self.border_mode
### Pad if mode == "full":
if mode == "half": mode = (kern.shape[2] - 1, kern.shape[3] - 1)
elif mode == "half":
mode = (kern.shape[2] // 2, kern.shape[3] // 2) mode = (kern.shape[2] // 2, kern.shape[3] // 2)
if isinstance(mode, tuple): if isinstance(mode, tuple):
pad_h, pad_w = map(int, mode) pad_h, pad_w = map(int, mode)
...@@ -521,14 +522,11 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -521,14 +522,11 @@ class AbstractConv2d(BaseAbstractConv2d):
img.shape[3] + 2 * pad_w), dtype=img.dtype) img.shape[3] + 2 * pad_w), dtype=img.dtype)
new_img[:, :, pad_h:img.shape[2]+pad_h, pad_w:img.shape[3]+pad_w] = img new_img[:, :, pad_h:img.shape[2]+pad_h, pad_w:img.shape[3]+pad_w] = img
img = new_img img = new_img
### Filter flip
if not self.filter_flip: if not self.filter_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
conv_out = self.corr2d(img, kern, mode) conv_out = self.corr2d(img, kern, mode)
### Subsample
conv_out =conv_out[:, :, ::self.subsample[0], ::self.subsample[1]] conv_out =conv_out[:, :, ::self.subsample[0], ::self.subsample[1]]
o[0] = conv_out o[0] = conv_out
...@@ -630,7 +628,9 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -630,7 +628,9 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
o, = out_ o, = out_
mode = self.border_mode mode = self.border_mode
if mode == "half": if mode == "full":
mode = (shape[0] - 1, shape[1] - 1)
elif mode == "half":
mode = (shape[0] // 2, shape[1] // 2) mode = (shape[0] // 2, shape[1] // 2)
if isinstance(mode, tuple): if isinstance(mode, tuple):
pad_h, pad_w = map(int, mode) pad_h, pad_w = map(int, mode)
...@@ -638,9 +638,11 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -638,9 +638,11 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
new_img = numpy.zeros((img.shape[0], img.shape[1], new_img = numpy.zeros((img.shape[0], img.shape[1],
img.shape[2] + 2 *pad_h, img.shape[2] + 2 *pad_h,
img.shape[3] + 2 * pad_w), dtype=img.dtype) img.shape[3] + 2 * pad_w), dtype=img.dtype)
#import pdb; pdb.set_trace()
new_img[:, :, pad_h:img.shape[2]+pad_h, pad_w:img.shape[3]+pad_w] = img new_img[:, :, pad_h:img.shape[2]+pad_h, pad_w:img.shape[3]+pad_w] = img
img = new_img img = new_img
if self.subsample[0] > 1 or self.subsample[1] > 1: if self.subsample[0] > 1 or self.subsample[1] > 1:
new_shape = (topgrad.shape[0], topgrad.shape[1], new_shape = (topgrad.shape[0], topgrad.shape[1],
img.shape[2] - shape[0] + 1, img.shape[2] - shape[0] + 1,
...@@ -752,9 +754,14 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -752,9 +754,14 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
mode = self.border_mode mode = self.border_mode
pad_h, pad_w = 0, 0 pad_h, pad_w = 0, 0
if isinstance(mode, tuple):
mode = "valid" if mode == "full":
pad_h, pad_w = (kern.shape[2] - 1, kern.shape[3] - 1)
elif mode == "half":
pad_h, pad_w = (kern.shape[2] // 2, kern.shape[3] // 2)
elif isinstance(mode, tuple):
pad_h, pad_w = map(int, self.border_mode) pad_h, pad_w = map(int, self.border_mode)
mode = "valid"
if self.subsample[0] > 1 or self.subsample[1] > 1: if self.subsample[0] > 1 or self.subsample[1] > 1:
new_shape = (topgrad.shape[0], topgrad.shape[1], new_shape = (topgrad.shape[0], topgrad.shape[1],
...@@ -769,8 +776,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -769,8 +776,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
img = self.corr2d(topgrad, kern, mode="full") img = self.corr2d(topgrad, kern, mode="full")
if self.filter_flip: if self.filter_flip:
img = img[:, :, ::-1, ::-1] img = img[:, :, ::-1, ::-1]
if isinstance(self.border_mode, tuple): if pad_h > 0 or pad_w > 0:
pad_h, pad_w = map(int, self.border_mode)
img = img[:, :, pad_h:img.shape[2]-pad_h, pad_w:img.shape[2]-pad_w] img = img[:, :, pad_h:img.shape[2]-pad_h, pad_w:img.shape[2]-pad_w]
o[0] = img o[0] = img
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论