提交 bbc749ab authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Handle filter_dilation in the abstractconv_grad* perform methods

上级 bcd856b4
......@@ -972,6 +972,9 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
mode = (shape[0] // 2, shape[1] // 2)
if isinstance(mode, tuple):
pad_h, pad_w = map(int, mode)
pad_h = (pad_h - 1) * self.filter_dilation[0] + 1
pad_w = (pad_w - 1) * self.filter_dilation[1] + 1
mode = "valid"
new_img = numpy.zeros((img.shape[0], img.shape[1],
img.shape[2] + 2 * pad_h,
......@@ -980,9 +983,11 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
img = new_img
if self.subsample[0] > 1 or self.subsample[1] > 1:
dil_shape = ((shape[0] - 1) * self.filter_dilation[0] + 1,
(shape[1] - 1) * self.filter_dilation[1] + 1)
new_shape = (topgrad.shape[0], topgrad.shape[1],
img.shape[2] - shape[0] + 1,
img.shape[3] - shape[1] + 1)
img.shape[2] - dil_shape[0] + 1,
img.shape[3] - dil_shape[1] + 1)
new_topgrad = numpy.zeros((new_shape), dtype=topgrad.dtype)
new_topgrad[:, :, ::self.subsample[0], ::self.subsample[1]] = topgrad
topgrad = new_topgrad
......@@ -990,6 +995,8 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
topgrad = topgrad.transpose(1, 0, 2, 3)[:, :, ::-1, ::-1]
img = img.transpose(1, 0, 2, 3)
kern = self.conv2d(img, topgrad, mode="valid")
if self.filter_dilation[0] > 1 or self.filter_dilation[1] > 1:
kern = kern[:, :, ::self.filter_dilation[0], ::self.filter_dilation[1]]
if self.filter_flip:
kern = kern.transpose(1, 0, 2, 3)[:, :, ::-1, ::-1]
else:
......@@ -1103,24 +1110,26 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
'"valid", "full", "half", an integer or a pair of'
' integers'.format(mode))
dil_kernshp = ((kern.shape[2] - 1) * self.filter_dilation[0] + 1,
(kern.shape[3] - 1) * self.filter_dilation[1] + 1)
pad_h, pad_w = 0, 0
if mode == "full":
pad_h, pad_w = (kern.shape[2] - 1, kern.shape[3] - 1)
pad_h, pad_w = (dil_kernshp[0] - 1, dil_kernshp[0] - 1)
elif mode == "half":
pad_h, pad_w = (kern.shape[2] // 2, kern.shape[3] // 2)
pad_h, pad_w = (dil_kernshp[1] // 2, dil_kernshp[1] // 2)
elif isinstance(mode, tuple):
pad_h, pad_w = map(int, self.border_mode)
if self.subsample[0] > 1 or self.subsample[1] > 1:
new_shape = (topgrad.shape[0], topgrad.shape[1],
shape[0] + 2 * pad_h - kern.shape[2] + 1,
shape[1] + 2 * pad_w - kern.shape[3] + 1)
shape[0] + 2 * pad_h - dil_kernshp[0] + 1,
shape[1] + 2 * pad_w - dil_kernshp[1] + 1)
new_topgrad = numpy.zeros((new_shape), dtype=topgrad.dtype)
new_topgrad[:, :, ::self.subsample[0], ::self.subsample[1]] = topgrad
topgrad = new_topgrad
kern = kern.transpose(1, 0, 2, 3)
if self.filter_flip:
topgrad = topgrad[:, :, ::-1, ::-1]
img = self.conv2d(topgrad, kern, mode="full")
img = self.conv2d(topgrad, kern, mode="full", dilation=self.filter_dilation)
if self.filter_flip:
img = img[:, :, ::-1, ::-1]
if pad_h > 0 or pad_w > 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论