Commit 28f858c6 authored by affanv14

modify abstractconv3d to support num_groups

Parent 26d47057
@@ -671,7 +671,8 @@ def conv3d(input,
            border_mode='valid',
            subsample=(1, 1, 1),
            filter_flip=True,
-           filter_dilation=(1, 1, 1)):
+           filter_dilation=(1, 1, 1),
+           num_groups=1):
     """
     This function will build the symbolic graph for convolving a mini-batch of a
     stack of 3D inputs with a set of 3D filters. The implementation is modelled
@@ -759,7 +760,8 @@ def conv3d(input,
                              border_mode=border_mode,
                              subsample=subsample,
                              filter_flip=filter_flip,
-                             filter_dilation=filter_dilation)
+                             filter_dilation=filter_dilation,
+                             num_groups=num_groups)
     return conv_op(input, filters)
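
With this change, grouped convolution can be requested directly from the Python-level conv3d interface. A minimal usage sketch (the shapes and variable names below are illustrative, not taken from this commit): with num_groups=2, the 4 input channels are split into two groups of 2, and each group is convolved with its own 4 of the 8 filters.

import theano
import theano.tensor as T
from theano.tensor.nnet.abstract_conv import conv3d

# Symbolic 5-d inputs: (batch, channels, depth, height, width)
input = T.tensor5('input')
# Filters: (out_channels, in_channels // num_groups, kd, kh, kw)
filters = T.tensor5('filters')

out = conv3d(input, filters,
             input_shape=(2, 4, 5, 5, 5),    # illustrative shapes
             filter_shape=(8, 2, 3, 3, 3),
             border_mode='valid',
             num_groups=2)
f = theano.function([input, filters], out)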
@@ -1603,8 +1605,6 @@ class BaseAbstractConv(Op):
         self.filter_dilation = tuple(filter_dilation)
         if num_groups < 1:
             raise ValueError("num_groups must have value greater than zero")
-        elif num_groups > 1 and convdim == 3:
-            raise ValueError("grouped convolution not supported for 3D convolutions")
         self.num_groups = num_groups

     def do_constant_folding(self, node):
@@ -1664,7 +1664,6 @@ class BaseAbstractConv(Op):
                      tuple(slice(None, None, dilation[i]) for i in range(self.convdim))
                      ] = kern

-        if self.convdim == 2:
         if img.shape[1] % self.num_groups != 0:
             raise ValueError(
                 'number of input channels must be divible by num_groups')
@@ -1675,11 +1674,13 @@ class BaseAbstractConv(Op):
             raise ValueError(
                 'the number of input channels in the kernel should '
                 'specify the number of channels of 1 group')
-            val = _valfrommode(mode)
-            bval = _bvalfromboundary('fill')
         input_channel_offset = img.shape[1] // self.num_groups
         output_channel_offset = kern.shape[0] // self.num_groups

+        if self.convdim == 2:
+            val = _valfrommode(mode)
+            bval = _bvalfromboundary('fill')
+
             with warnings.catch_warnings():
                 warnings.simplefilter('ignore', np.ComplexWarning)
                 for b in xrange(img.shape[0]):
@@ -1692,11 +1693,12 @@ class BaseAbstractConv(Op):
                                                 im0, ...], 1, val, bval, 0)
         elif self.convdim == 3:
             for b in xrange(img.shape[0]):
-                for n in xrange(kern.shape[0]):
-                    for im0 in xrange(img.shape[1]):
-                        out[b, n, ...] += convolve(img[b, im0, ...],
-                                                   dilated_kern[n, im0, ...],
-                                                   mode)
+                for g in xrange(self.num_groups):
+                    for n in xrange(output_channel_offset):
+                        for im0 in xrange(input_channel_offset):
+                            out[b, g * output_channel_offset + n, ...] += convolve(img[b, g * input_channel_offset + im0, ...],
+                                                                                   dilated_kern[g * output_channel_offset + n,
+                                                                                                im0, ...], mode)
         else:
             raise NotImplementedError('only 2D and 3D convolution are implemented')
         return out
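
The rewritten 3D branch mirrors the existing 2D path: output channel g * output_channel_offset + n is accumulated only from the input channels of group g. A standalone NumPy sketch of the same indexing, ignoring dilation and using scipy.signal.convolve in place of the Op's internal per-channel convolution (the function name and shapes are illustrative):

import numpy as np
from scipy.signal import convolve

def grouped_conv3d(img, kern, num_groups=1):
    # img:  (batch, in_channels, depth, height, width)
    # kern: (out_channels, in_channels // num_groups, kd, kh, kw)
    # Illustrative 'valid'-mode re-implementation of the grouping logic only
    # (no dilation, no border-mode or filter-flip options).
    assert img.shape[1] % num_groups == 0
    assert kern.shape[0] % num_groups == 0
    in_off = img.shape[1] // num_groups    # input channels per group
    out_off = kern.shape[0] // num_groups  # filters per group
    out_spatial = tuple(i - k + 1 for i, k in zip(img.shape[2:], kern.shape[2:]))
    out = np.zeros((img.shape[0], kern.shape[0]) + out_spatial, dtype=img.dtype)
    for b in range(img.shape[0]):
        for g in range(num_groups):
            for n in range(out_off):
                for im0 in range(in_off):
                    # output channel g*out_off + n only sees inputs of group g
                    out[b, g * out_off + n] += convolve(
                        img[b, g * in_off + im0],
                        kern[g * out_off + n, im0], mode='valid')
    return out

# e.g. img of shape (2, 4, 7, 7, 7) and kern of shape (6, 2, 3, 3, 3)
# with num_groups=2 gives an output of shape (2, 6, 5, 5, 5).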
@@ -1888,13 +1890,15 @@ class AbstractConv3d(AbstractConv):
                  border_mode="valid",
                  subsample=(1, 1, 1),
                  filter_flip=True,
-                 filter_dilation=(1, 1, 1)):
+                 filter_dilation=(1, 1, 1),
+                 num_groups=1):
         super(AbstractConv3d, self).__init__(convdim=3,
                                              imshp=imshp, kshp=kshp,
                                              border_mode=border_mode,
                                              subsample=subsample,
                                              filter_flip=filter_flip,
-                                             filter_dilation=filter_dilation)
+                                             filter_dilation=filter_dilation,
+                                             num_groups=num_groups)

     def grad(self, inp, grads):
         bottom, weights = inp
@@ -1903,13 +1907,15 @@ class AbstractConv3d(AbstractConv):
                                              self.border_mode,
                                              self.subsample,
                                              self.filter_flip,
-                                             self.filter_dilation)(
+                                             self.filter_dilation,
+                                             self.num_groups)(
             weights, top, bottom.shape[-3:])
         d_weights = AbstractConv3d_gradWeights(self.imshp, self.kshp,
                                                self.border_mode,
                                                self.subsample,
                                                self.filter_flip,
-                                               self.filter_dilation)(
+                                               self.filter_dilation,
+                                               self.num_groups)(
             bottom, top, weights.shape[-3:])
@@ -2033,8 +2039,8 @@ class AbstractConv_gradWeights(BaseAbstractConv):
             mshp0 = mat.shape[0] // self.num_groups
             mshp1 = mat.shape[1] * self.num_groups
             mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
-            mat = mat.transpose((1, 0, 2, 3, 4))
-            mat = mat.reshape((mshp0, mshp1) + mat.shape[-2:])
+            mat = mat.transpose((1, 0, 2) + tuple(range(3, 3 + self.convdim)))
+            mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim:])
             return mat

         if self.num_groups > 1:
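
The transpose is generalized from the hard-coded 2D pattern (1, 0, 2, 3, 4) to (1, 0, 2) + range(3, 3 + convdim), so the same group-reordering works for both 4-d (2D) and 5-d (3D) arrays. A small NumPy check of that index manipulation (the helper name and shapes are made up):

import numpy as np

def regroup(mat, num_groups, convdim):
    # Fold the group axis out of axis 0 and interleave it into axis 1,
    # matching the reshape/transpose pattern used in the grouped gradWeights path.
    mshp0 = mat.shape[0] // num_groups
    mshp1 = mat.shape[1] * num_groups
    mat = mat.reshape((num_groups, mshp0) + mat.shape[1:])
    mat = mat.transpose((1, 0, 2) + tuple(range(3, 3 + convdim)))
    return mat.reshape((mshp0, mshp1) + mat.shape[-convdim:])

# Works unchanged for 2D (4-d arrays) and 3D (5-d arrays):
assert regroup(np.zeros((6, 4, 7, 7)), 2, 2).shape == (3, 8, 7, 7)
assert regroup(np.zeros((6, 4, 7, 7, 7)), 2, 3).shape == (3, 8, 7, 7, 7)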
@@ -2147,13 +2153,15 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
                  border_mode="valid",
                  subsample=(1, 1, 1),
                  filter_flip=True,
-                 filter_dilation=(1, 1, 1)):
+                 filter_dilation=(1, 1, 1),
+                 num_groups=1):
         super(AbstractConv3d_gradWeights, self).__init__(convdim=3,
                                                          imshp=imshp, kshp=kshp,
                                                          border_mode=border_mode,
                                                          subsample=subsample,
                                                          filter_flip=filter_flip,
-                                                         filter_dilation=filter_dilation)
+                                                         filter_dilation=filter_dilation,
+                                                         num_groups=num_groups)

     def grad(self, inp, grads):
         bottom, top = inp[:2]
@@ -2162,7 +2170,8 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
                                              self.border_mode,
                                              self.subsample,
                                              self.filter_flip,
-                                             self.filter_dilation)(weights,
+                                             self.filter_dilation,
+                                             self.num_groups)(weights,
                                                                top,
                                                                bottom.shape[-3:])
         d_top = AbstractConv3d(self.imshp,
@@ -2170,7 +2179,8 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
                                self.border_mode,
                                self.subsample,
                                self.filter_flip,
-                               self.filter_dilation)(bottom, weights)
+                               self.filter_dilation,
+                               self.num_groups)(bottom, weights)
         # Make sure that the broadcastable pattern of the inputs is used
         # for the gradients, even if the grad opts are not able to infer
         # that the dimensions are broadcastable.
@@ -2414,13 +2424,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs):
                  border_mode="valid",
                  subsample=(1, 1, 1),
                  filter_flip=True,
-                 filter_dilation=(1, 1, 1)):
+                 filter_dilation=(1, 1, 1),
+                 num_groups=1):
         super(AbstractConv3d_gradInputs, self).__init__(convdim=3,
                                                         imshp=imshp, kshp=kshp,
                                                         border_mode=border_mode,
                                                         subsample=subsample,
                                                         filter_flip=filter_flip,
-                                                        filter_dilation=filter_dilation)
+                                                        filter_dilation=filter_dilation,
+                                                        num_groups=num_groups)

     def grad(self, inp, grads):
         weights, top = inp[:2]
@@ -2429,13 +2441,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs):
                                                self.border_mode,
                                                self.subsample,
                                                self.filter_flip,
-                                               self.filter_dilation)(bottom, top,
+                                               self.filter_dilation,
+                                               self.num_groups)(bottom, top,
                                                                  weights.shape[-3:])
         d_top = AbstractConv3d(self.imshp, self.kshp,
                                self.border_mode,
                                self.subsample,
                                self.filter_flip,
-                               self.filter_dilation)(bottom, weights)
+                               self.filter_dilation,
+                               self.num_groups)(bottom, weights)
         # Make sure that the broadcastable pattern of the inputs is used
         # for the gradients, even if the grad opts are not able to infer
         # that the dimensions are broadcastable.