提交 caccc5f8 作者: Vikram

Changes in corr.py and abstract_conv.py

上级 38e03e0f
...@@ -1751,21 +1751,30 @@ class BaseAbstractConv(Op): ...@@ -1751,21 +1751,30 @@ class BaseAbstractConv(Op):
filter_dilation = (1,) * convdim filter_dilation = (1,) * convdim
if isinstance(border_mode, integer_types): if isinstance(border_mode, integer_types):
border_mode = (border_mode,) * convdim if border_mode < 0:
if isinstance(border_mode, tuple): raise ValueError(
'invalid border_mode {}, which must be a '
'non-negative integer'.format(border_mode))
border_mode = ((border_mode, border_mode),) * convdim
elif isinstance(border_mode, tuple):
if len(border_mode) != convdim: if len(border_mode) != convdim:
raise ValueError( raise ValueError(
'border mode must have exactly {} values, ' 'invalid border_mode {} which must be a '
'but was {}'.format(convdim, border_mode)) 'tuple of length {}'.format(border_mode, convdim))
border_mode = tuple(map(int, border_mode)) for mode in border_mode:
if border_mode == (0,) * convdim: if not((isinstance(mode, integer_types) and mode > 0) or
border_mode = 'valid' (isinstance(mode, tuple) and len(mode) == 2 and
if not ((isinstance(border_mode, tuple) and min(border_mode) >= 0) or min(mode) >= 0)):
border_mode in ('valid', 'full', 'half')): raise ValueError(
'invalid border mode {}. The tuple can only contain '
'integers or tuples of length 2'.format(border_mode))
elif border_mode not in ('valid', 'full', 'half'):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be either ' 'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a tuple of {}' '"valid", "full", "half", an integer or a tuple '
' integers'.format(border_mode, convdim)) 'of length {}'.format(border_mode, convdim))
if all(mode == (0, 0) or mode == 0 for mode in border_mode):
border_mode = 'valid'
self.imshp = tuple(imshp) if imshp else (None,) * (2 + convdim) self.imshp = tuple(imshp) if imshp else (None,) * (2 + convdim)
for imshp_i in self.imshp: for imshp_i in self.imshp:
...@@ -2026,26 +2035,42 @@ class AbstractConv(BaseAbstractConv): ...@@ -2026,26 +2035,42 @@ class AbstractConv(BaseAbstractConv):
o, = out_ o, = out_
mode = self.border_mode mode = self.border_mode
if not ((isinstance(mode, tuple) and min(mode) >= 0) or if isinstance(mode, tuple):
mode in ('valid', 'full', 'half')): if len(mode) != 2:
raise ValueError(
'invalid border_mode {} which must be a '
'tuple of length {}'.format(mode, self.convdim))
border = ()
for m in mode:
if isinstance(m, integer_types) and m > 0:
border += ((m, m),)
elif isinstance(m, tuple) and len(m) == 2 and \
min(m) >= 0:
border += ((int(m[0]), int(m[1])),)
else:
raise ValueError(
'invalid border mode {}. The tuple can only contain '
'integers or tuples of length 2'.format(mode))
mode = border
elif mode not in ('valid', 'full', 'half'):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be either ' 'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a tuple of' '"valid", "full", "half", an integer or a tuple '
' integers'.format(mode)) 'of length {}'.format(mode, self.convdim))
if mode == "full": if mode == "full":
mode = tuple(dil_kernshp[i] - 1 for i in range(self.convdim)) mode = tuple((dil_kernshp[i] - 1,) * 2 for i in range(self.convdim))
elif mode == "half": elif mode == "half":
mode = tuple(dil_kernshp[i] // 2 for i in range(self.convdim)) mode = tuple((dil_kernshp[i] // 2,) * 2 for i in range(self.convdim))
if isinstance(mode, tuple): if isinstance(mode, tuple):
pad = tuple(int(mode[i]) for i in range(self.convdim)) pad = mode
mode = "valid" mode = "valid"
new_img = np.zeros((img.shape[0], img.shape[1]) + new_img = np.zeros((img.shape[0], img.shape[1]) +
tuple(img.shape[i + 2] + 2 * pad[i] tuple(img.shape[i + 2] + pad[i][0] + pad[i][1]
for i in range(self.convdim)), for i in range(self.convdim)),
dtype=img.dtype) dtype=img.dtype)
new_img[(slice(None), slice(None)) + new_img[(slice(None), slice(None)) +
tuple(slice(pad[i], img.shape[i + 2] + pad[i]) tuple(slice(pad[i][0], img.shape[i + 2] + pad[i][0])
for i in range(self.convdim))] = img for i in range(self.convdim))] = img
img = new_img img = new_img
if not self.filter_flip: if not self.filter_flip:
...@@ -2297,12 +2322,28 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -2297,12 +2322,28 @@ class AbstractConv_gradWeights(BaseAbstractConv):
o, = out_ o, = out_
mode = self.border_mode mode = self.border_mode
if not ((isinstance(mode, tuple) and min(mode) >= 0) or if isinstance(mode, tuple):
mode in ('valid', 'full', 'half')): if len(mode) != 2:
raise ValueError(
'invalid border_mode {} which must be a '
'tuple of length {}'.format(mode, self.convdim))
border = ()
for m in mode:
if isinstance(m, integer_types) and m > 0:
border += ((m, m),)
elif isinstance(m, tuple) and len(m) == 2 and \
min(m) >= 0:
border += ((int(m[0]), int(m[1])),)
else:
raise ValueError(
'invalid border mode {}. The tuple can only contain '
'integers or tuples of length 2'.format(mode))
mode = border
elif mode not in ('valid', 'full', 'half'):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be either ' 'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a tuple of' '"valid", "full", "half", an integer or a tuple '
' integers'.format(mode)) 'of length {}'.format(mode, self.convdim))
if self.unshared and self.convdim != 2: if self.unshared and self.convdim != 2:
raise NotImplementedError('Unshared convolution not implemented for %dD' raise NotImplementedError('Unshared convolution not implemented for %dD'
% self.convdim) % self.convdim)
...@@ -2311,19 +2352,18 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -2311,19 +2352,18 @@ class AbstractConv_gradWeights(BaseAbstractConv):
for i in range(self.convdim)) for i in range(self.convdim))
if mode == "full": if mode == "full":
mode = tuple(dil_shape[i] - 1 for i in range(self.convdim)) mode = tuple((dil_shape[i] - 1,) * 2 for i in range(self.convdim))
elif mode == "half": elif mode == "half":
mode = tuple(dil_shape[i] // 2 for i in range(self.convdim)) mode = tuple((dil_shape[i] // 2,) * 2 for i in range(self.convdim))
if isinstance(mode, tuple): if isinstance(mode, tuple):
pad = tuple(int(mode[i]) for i in range(self.convdim)) pad = mode
mode = "valid" mode = "valid"
new_img = np.zeros((img.shape[0], img.shape[1]) + new_img = np.zeros((img.shape[0], img.shape[1]) +
tuple(img.shape[i + 2] + 2 * pad[i] tuple(img.shape[i + 2] + pad[i][0] + pad[i][1]
for i in range(self.convdim)), for i in range(self.convdim)),
dtype=img.dtype) dtype=img.dtype)
new_img[(slice(None), slice(None)) + new_img[(slice(None), slice(None)) +
tuple(slice(pad[i], img.shape[i + 2] + pad[i]) tuple(slice(pad[i][0], img.shape[i + 2] + pad[i][0])
for i in range(self.convdim))] = img for i in range(self.convdim))] = img
img = new_img img = new_img
...@@ -2612,12 +2652,28 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2612,12 +2652,28 @@ class AbstractConv_gradInputs(BaseAbstractConv):
o, = out_ o, = out_
mode = self.border_mode mode = self.border_mode
if not ((isinstance(mode, tuple) and min(mode) >= 0) or if isinstance(mode, tuple):
mode in ('valid', 'full', 'half')): if len(mode) != 2:
raise ValueError(
'invalid border_mode {} which must be a '
'tuple of length {}'.format(mode, self.convdim))
border = ()
for m in mode:
if isinstance(m, integer_types) and m > 0:
border += ((m, m),)
elif isinstance(m, tuple) and len(m) == 2 and \
min(m) >= 0:
border += ((int(m[0]), int(m[1])),)
else:
raise ValueError(
'invalid border mode {}. The tuple can only contain '
'integers or tuples of length 2'.format(mode))
mode = border
elif mode not in ('valid', 'full', 'half'):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be either ' 'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a tuple of' '"valid", "full", "half", an integer or a tuple '
' integers'.format(mode)) 'of length {}'.format(mode, self.convdim))
if self.unshared and self.convdim != 2: if self.unshared and self.convdim != 2:
raise NotImplementedError('Unshared convolution not implemented for %dD' raise NotImplementedError('Unshared convolution not implemented for %dD'
% self.convdim) % self.convdim)
...@@ -2642,14 +2698,14 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2642,14 +2698,14 @@ class AbstractConv_gradInputs(BaseAbstractConv):
pad = (0,) * self.convdim pad = (0,) * self.convdim
if mode == "full": if mode == "full":
pad = tuple(dil_kernshp[i] - 1 for i in range(self.convdim)) pad = tuple((dil_kernshp[i] - 1,) * 2 for i in range(self.convdim))
elif mode == "half": elif mode == "half":
pad = tuple(dil_kernshp[i] // 2 for i in range(self.convdim)) pad = tuple((dil_kernshp[i] // 2,) * 2 for i in range(self.convdim))
elif isinstance(mode, tuple): elif isinstance(mode, tuple):
pad = tuple(mode[i] for i in range(self.convdim)) pad = mode
if any(self.subsample[i] > 1 for i in range(self.convdim)): if any(self.subsample[i] > 1 for i in range(self.convdim)):
new_shape = ((topgrad.shape[0], topgrad.shape[1]) + new_shape = ((topgrad.shape[0], topgrad.shape[1]) +
tuple(shape[i] + 2 * pad[i] - dil_kernshp[i] + 1 tuple(shape[i] + pad[i][0] + pad[i][1] - dil_kernshp[i] + 1
for i in range(self.convdim))) for i in range(self.convdim)))
new_topgrad = np.zeros((new_shape), dtype=topgrad.dtype) new_topgrad = np.zeros((new_shape), dtype=topgrad.dtype)
new_topgrad[(slice(None), slice(None)) + new_topgrad[(slice(None), slice(None)) +
...@@ -2705,9 +2761,9 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2705,9 +2761,9 @@ class AbstractConv_gradInputs(BaseAbstractConv):
if self.filter_flip: if self.filter_flip:
img = img[flip_filters] img = img[flip_filters]
if any(p > 0 for p in pad): if any(p != (0, 0) or p != 0 for p in pad):
img = img[(slice(None), slice(None)) + img = img[(slice(None), slice(None)) +
tuple(slice(pad[i], img.shape[i + 2] - pad[i]) tuple(slice(pad[i][0], img.shape[i + 2] - pad[i][0])
for i in range(self.convdim))] for i in range(self.convdim))]
o[0] = node.outputs[0].type.filter(img) o[0] = node.outputs[0].type.filter(img)
......
...@@ -55,7 +55,8 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -55,7 +55,8 @@ class BaseCorrMM(gof.OpenMPOp):
('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2 ('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2
dH=int64, dW=int64, dH=int64, dW=int64,
dilH=int64, dilW=int64, dilH=int64, dilW=int64,
padH=int64, padW=int64, padH_l=int64, padH_r=int64,
padW_l=int64, padW_r=int64,
num_groups=int64, unshared=int8) num_groups=int64, unshared=int8)
def __init__(self, border_mode="valid", subsample=(1, 1), def __init__(self, border_mode="valid", subsample=(1, 1),
...@@ -78,7 +79,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -78,7 +79,7 @@ class BaseCorrMM(gof.OpenMPOp):
border += ((mode, mode),) border += ((mode, mode),)
elif isinstance(mode, tuple) and len(mode) == 2 and \ elif isinstance(mode, tuple) and len(mode) == 2 and \
min(mode) >= 0: min(mode) >= 0:
border = ((mode[0], mode[1]),) border += ((int(mode[0]), int(mode[1])),)
else: else:
raise ValueError( raise ValueError(
'invalid border mode {}. The tuple can only contain ' 'invalid border mode {}. The tuple can only contain '
...@@ -347,7 +348,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -347,7 +348,7 @@ class BaseCorrMM(gof.OpenMPOp):
// kernel height is specified (perhaps vertical subsampling or half padding) // kernel height is specified (perhaps vertical subsampling or half padding)
kH = %(height)s; kH = %(height)s;
} }
else if (padH == -2) { else if (padH_l == -2 || padH_r == -2) {
// vertical full padding, we can infer the kernel height // vertical full padding, we can infer the kernel height
kH = (2 - PyArray_DIMS(bottom)[2] + (PyArray_DIMS(top)[2] - 1) * dH - 1)/ dilH + 1; kH = (2 - PyArray_DIMS(bottom)[2] + (PyArray_DIMS(top)[2] - 1) * dH - 1)/ dilH + 1;
} }
...@@ -359,7 +360,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -359,7 +360,7 @@ class BaseCorrMM(gof.OpenMPOp):
// kernel width is specified (perhaps horizontal subsampling or half padding) // kernel width is specified (perhaps horizontal subsampling or half padding)
kW = %(width)s; kW = %(width)s;
} }
else if (padW == -2) { else if (padW_l == -2 || padW_r == -2) {
kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1; kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
} }
else { else {
...@@ -372,24 +373,24 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -372,24 +373,24 @@ class BaseCorrMM(gof.OpenMPOp):
dil_kW = (kW - 1) * dilW + 1; dil_kW = (kW - 1) * dilW + 1;
// Auto-padding if requested // Auto-padding if requested
if (padH == -1) { // vertical half padding if (padH_l == -1 || padH_r == -1) { // vertical half padding
padH = dil_kH / 2; padH_l = padH_r = dil_kH / 2;
} }
else if (padH == -2) { // vertical full padding else if (padH_l == -2 || padH_r == -2) { // vertical full padding
padH = dil_kH - 1; padH_l = padH_r = dil_kH - 1;
} }
else if (padH < 0) { else if (padH_l < -2 || padH_r < -2) {
PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padH must be >= -2"); PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padH_l and padH_r must be >= -2");
%(fail)s %(fail)s
} }
if (padW == -1) { // horizontal half padding if (padW_l == -1 || padW_r == -1) { // horizontal half padding
padW = dil_kW / 2; padW_l = padW_r = dil_kW / 2;
} }
else if (padW == -2) { // horizontal full padding else if (padW_l == -2 || padW_r == -2) { // horizontal full padding
padW = dil_kW - 1; padW_l = padW_r = dil_kW - 1;
} }
else if (padW < 0) { else if (padW_l < -2 || padW_r < -2) {
PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padW must be >= -2"); PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padW_l and padW_r must be >= -2");
%(fail)s %(fail)s
} }
...@@ -720,14 +721,14 @@ class CorrMM_gradWeights(BaseCorrMM): ...@@ -720,14 +721,14 @@ class CorrMM_gradWeights(BaseCorrMM):
def infer_shape(self, node, input_shape): def infer_shape(self, node, input_shape):
if self.border_mode == "half": if self.border_mode == "half":
padH = padW = -1 padH_l = padH_r = padW_l = padW_r = -1
elif self.border_mode == "full": elif self.border_mode == "full":
padH = padW = -2 padH_l = padH_r = padW_l = padW_r = -2
elif isinstance(self.border_mode, tuple): elif isinstance(self.border_mode, tuple):
padH, padW = self.border_mode (padH_l, padH_r), (padW_l, padW_r) = self.border_mode
else: else:
assert self.border_mode == "valid" assert self.border_mode == "valid"
padH = padW = 0 padH_l = padH_r = padW_l = padW_r = 0
dH, dW = self.subsample dH, dW = self.subsample
imshp = input_shape[0] imshp = input_shape[0]
topshp = input_shape[1] topshp = input_shape[1]
...@@ -735,21 +736,21 @@ class CorrMM_gradWeights(BaseCorrMM): ...@@ -735,21 +736,21 @@ class CorrMM_gradWeights(BaseCorrMM):
ssize = ssize // self.num_groups ssize = ssize // self.num_groups
nkern, topshp = topshp[1], list(topshp[2:]) nkern, topshp = topshp[1], list(topshp[2:])
height_width = node.inputs[-2:] height_width = node.inputs[-2:]
if ((dH != 1) or (padH == -1)): if ((dH != 1) or (padH_l == -1) or (padH_r == -1)):
# vertical subsampling or half padding, kernel height is specified # vertical subsampling or half padding, kernel height is specified
kH = height_width[0] kH = height_width[0]
elif padH == -2: elif (padH_l == -2) or (padH_r == -2):
# vertical full padding, we can infer the kernel height # vertical full padding, we can infer the kernel height
kH = 2 - imshp[0] + (topshp[0] - 1) * dH kH = 2 - imshp[0] + (topshp[0] - 1) * dH
else: else:
# explicit padding, we can infer the kernel height # explicit padding, we can infer the kernel height
kH = imshp[0] + 2 * padH - (topshp[0] - 1) * dH kH = imshp[0] + padH_l + padH_r - (topshp[0] - 1) * dH
if ((dW != 1) or (padW == -1)): if ((dW != 1) or (padW_l == -1) or (padW_r == -1)):
kW = height_width[1] kW = height_width[1]
elif (padW == -2): elif (padW_l == -2) or (padW_r == -2):
kW = 2 - imshp[1] + (topshp[1] - 1) * dW kW = 2 - imshp[1] + (topshp[1] - 1) * dW
else: else:
kW = imshp[1] + 2 * padW - (topshp[1] - 1) * dW kW = imshp[1] + padW_l + padW_r - (topshp[1] - 1) * dW
if self.unshared is True: if self.unshared is True:
return [(nkern, topshp[0], topshp[1], ssize, kH, kW)] return [(nkern, topshp[0], topshp[1], ssize, kH, kW)]
else: else:
...@@ -834,14 +835,14 @@ class CorrMM_gradInputs(BaseCorrMM): ...@@ -834,14 +835,14 @@ class CorrMM_gradInputs(BaseCorrMM):
def infer_shape(self, node, input_shape): def infer_shape(self, node, input_shape):
if self.border_mode == "half": if self.border_mode == "half":
padH = padW = -1 padH_l = padH_r = padW_l = padW_r = -1
elif self.border_mode == "full": elif self.border_mode == "full":
padH = padW = -2 padH_l = padH_r = padW_l = padW_r = -2
elif isinstance(self.border_mode, tuple): elif isinstance(self.border_mode, tuple):
padH, padW = self.border_mode (padH_l, padH_r), (padW_l, padW_r) = self.border_mode
else: else:
assert self.border_mode == "valid" assert self.border_mode == "valid"
padH = padW = 0 padH_l = padH_r = padW_l = padW_r = 0
dH, dW = self.subsample dH, dW = self.subsample
kshp = input_shape[0] kshp = input_shape[0]
topshp = input_shape[1] topshp = input_shape[1]
...@@ -849,27 +850,27 @@ class CorrMM_gradInputs(BaseCorrMM): ...@@ -849,27 +850,27 @@ class CorrMM_gradInputs(BaseCorrMM):
ssize = ssize * self.num_groups ssize = ssize * self.num_groups
bsize, topshp = topshp[0], list(topshp[2:]) bsize, topshp = topshp[0], list(topshp[2:])
height_width = node.inputs[-2:] height_width = node.inputs[-2:]
if padH == -1: if padH_l == -1 or padH_r == -1:
padH = kshp[0] // 2 padH_l = padH_r = kshp[0] // 2
elif padH == -2: elif padH_l == -2 or padH_r == -2:
padH = kshp[0] - 1 padH_l = padH_r = kshp[0] - 1
elif padH < -2: elif padH_l < -2 or padH_r < -2:
raise ValueError('CorrMM_gradInputs: border_mode must be >= 0.') raise ValueError('CorrMM_gradInputs: border_mode must be >= 0.')
if padW == -1: if padW_l == -1 or padW_r == -1:
padW = kshp[1] // 2 padW_l = padW_r = kshp[1] // 2
elif padW == -2: elif padW_l == -2 or padW_r == -2:
padW = kshp[1] - 1 padW_l = padW_r = kshp[1] - 1
elif padW < -2: elif padW_l < -2 or padW_r < -2:
raise ValueError('CorrMM_gradInputs: border_mode must be >= 0.') raise ValueError('CorrMM_gradInputs: border_mode must be >= 0.')
if dH != 1: if dH != 1:
out_shp0 = height_width[0] out_shp0 = height_width[0]
else: else:
out_shp0 = (topshp[0] - 1) * dH + kshp[0] - 2 * padH out_shp0 = (topshp[0] - 1) * dH + kshp[0] - padH_l - padH_r
if dW != 1: if dW != 1:
out_shp1 = height_width[1] out_shp1 = height_width[1]
else: else:
out_shp1 = (topshp[1] - 1) * dW + kshp[1] - 2 * padW out_shp1 = (topshp[1] - 1) * dW + kshp[1] - padW_l - padW_r
out_shp = (out_shp0, out_shp1) out_shp = (out_shp0, out_shp1)
return [(bsize, ssize) + out_shp] return [(bsize, ssize) + out_shp]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论