提交 28f858c6 authored 作者: affanv14's avatar affanv14

Modify AbstractConv3d to support num_groups (grouped 3D convolution)

上级 26d47057
......@@ -671,7 +671,8 @@ def conv3d(input,
border_mode='valid',
subsample=(1, 1, 1),
filter_flip=True,
filter_dilation=(1, 1, 1)):
filter_dilation=(1, 1, 1),
num_groups=1):
"""
This function will build the symbolic graph for convolving a mini-batch of a
stack of 3D inputs with a set of 3D filters. The implementation is modelled
......@@ -759,7 +760,8 @@ def conv3d(input,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
return conv_op(input, filters)
......@@ -1603,8 +1605,6 @@ class BaseAbstractConv(Op):
self.filter_dilation = tuple(filter_dilation)
if num_groups < 1:
raise ValueError("num_groups must have value greater than zero")
elif num_groups > 1 and convdim == 3:
raise ValueError("grouped convolution not supported for 3D convolutions")
self.num_groups = num_groups
def do_constant_folding(self, node):
......@@ -1664,21 +1664,22 @@ class BaseAbstractConv(Op):
tuple(slice(None, None, dilation[i]) for i in range(self.convdim))
] = kern
if img.shape[1] % self.num_groups != 0:
raise ValueError(
'number of input channels must be divible by num_groups')
if kern.shape[0] % self.num_groups != 0:
raise ValueError(
'number of filters must be divisible by num_groups')
if img.shape[1] // num_groups != kern.shape[1]:
raise ValueError(
'the number of input channels in the kernel should '
'specify the number of channels of 1 group')
input_channel_offset = img.shape[1] // self.num_groups
output_channel_offset = kern.shape[0] // self.num_groups
if self.convdim == 2:
if img.shape[1] % self.num_groups != 0:
raise ValueError(
'number of input channels must be divible by num_groups')
if kern.shape[0] % self.num_groups != 0:
raise ValueError(
'number of filters must be divisible by num_groups')
if img.shape[1] // num_groups != kern.shape[1]:
raise ValueError(
'the number of input channels in the kernel should '
'specify the number of channels of 1 group')
val = _valfrommode(mode)
bval = _bvalfromboundary('fill')
input_channel_offset = img.shape[1] // self.num_groups
output_channel_offset = kern.shape[0] // self.num_groups
with warnings.catch_warnings():
warnings.simplefilter('ignore', np.ComplexWarning)
......@@ -1692,11 +1693,12 @@ class BaseAbstractConv(Op):
im0, ...], 1, val, bval, 0)
elif self.convdim == 3:
for b in xrange(img.shape[0]):
for n in xrange(kern.shape[0]):
for im0 in xrange(img.shape[1]):
out[b, n, ...] += convolve(img[b, im0, ...],
dilated_kern[n, im0, ...],
mode)
for g in xrange(self.num_groups):
for n in xrange(output_channel_offset):
for im0 in xrange(input_channel_offset):
out[b, g * output_channel_offset + n, ...] += convolve(img[b, g * input_channel_offset + im0, ...],
dilated_kern[g * output_channel_offset + n,
im0, ...], mode)
else:
raise NotImplementedError('only 2D and 3D convolution are implemented')
return out
......@@ -1888,13 +1890,15 @@ class AbstractConv3d(AbstractConv):
border_mode="valid",
subsample=(1, 1, 1),
filter_flip=True,
filter_dilation=(1, 1, 1)):
filter_dilation=(1, 1, 1),
num_groups=1):
super(AbstractConv3d, self).__init__(convdim=3,
imshp=imshp, kshp=kshp,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads):
bottom, weights = inp
......@@ -1903,13 +1907,15 @@ class AbstractConv3d(AbstractConv):
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(
self.filter_dilation,
self.num_groups)(
weights, top, bottom.shape[-3:])
d_weights = AbstractConv3d_gradWeights(self.imshp, self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(
self.filter_dilation,
self.num_groups)(
bottom, top, weights.shape[-3:])
......@@ -2033,8 +2039,8 @@ class AbstractConv_gradWeights(BaseAbstractConv):
mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
mat = mat.transpose((1, 0, 2, 3, 4))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-2:])
mat = mat.transpose((1, 0, 2) + tuple(range(3, 3 + self.convdim)))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim:])
return mat
if self.num_groups > 1:
......@@ -2147,13 +2153,15 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
border_mode="valid",
subsample=(1, 1, 1),
filter_flip=True,
filter_dilation=(1, 1, 1)):
filter_dilation=(1, 1, 1),
num_groups=1):
super(AbstractConv3d_gradWeights, self).__init__(convdim=3,
imshp=imshp, kshp=kshp,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads):
bottom, top = inp[:2]
......@@ -2162,15 +2170,17 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(weights,
top,
bottom.shape[-3:])
self.filter_dilation,
self.num_groups)(weights,
top,
bottom.shape[-3:])
d_top = AbstractConv3d(self.imshp,
self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(bottom, weights)
self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
......@@ -2414,13 +2424,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs):
border_mode="valid",
subsample=(1, 1, 1),
filter_flip=True,
filter_dilation=(1, 1, 1)):
filter_dilation=(1, 1, 1),
num_groups=1):
super(AbstractConv3d_gradInputs, self).__init__(convdim=3,
imshp=imshp, kshp=kshp,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads):
weights, top = inp[:2]
......@@ -2429,13 +2441,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs):
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(bottom, top,
weights.shape[-3:])
self.filter_dilation,
self.num_groups)(bottom, top,
weights.shape[-3:])
d_top = AbstractConv3d(self.imshp, self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(bottom, weights)
self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论