Commit a2c62ae0 authored by affanv14, committed by Mohammed Affan

add basic group conv functionality to python implementation

Parent 5d260cde
@@ -66,7 +66,6 @@ def get_conv_output_shape(image_shape, kernel_shape,
     """
     bsize, imshp = image_shape[0], image_shape[2:]
     nkern, kshp = kernel_shape[0], kernel_shape[2:]
-
     if filter_dilation is None:
         filter_dilation = np.ones(len(subsample), dtype='int')
@@ -512,7 +511,8 @@ def conv2d(input,
            border_mode='valid',
            subsample=(1, 1),
            filter_flip=True,
-           filter_dilation=(1, 1)):
+           filter_dilation=(1, 1),
+           num_groups=1):
     """This function will build the symbolic graph for convolving a mini-batch of a
     stack of 2D inputs with a set of 2D filters. The implementation is modelled
     after Convolutional Neural Networks (CNN).

@@ -527,7 +527,8 @@ def conv2d(input,
                             border_mode=border_mode,
                             subsample=subsample,
                             filter_flip=filter_flip,
-                            filter_dilation=filter_dilation)
+                            filter_dilation=filter_dilation,
+                            num_groups=num_groups)
     return conv_op(input, filters)
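For orientation, a minimal usage sketch of the extended interface (the shapes and variable names are illustrative, not part of the patch). With num_groups=2, an 8-channel input is split into two groups of 4, and each of the 16 filters sees only its own group's channels, so the filter tensor's second dimension is input_channels // num_groups:

```python
import numpy as np
import theano
import theano.tensor as T
from theano.tensor.nnet.abstract_conv import conv2d

x = T.tensor4('x')  # input:   (batch, 8, H, W)
w = T.tensor4('w')  # filters: (16, 8 // 2, kh, kw) = (16, 4, kh, kw)
y = conv2d(x, w, num_groups=2)

f = theano.function([x, w], y)
out = f(np.random.randn(1, 8, 32, 32).astype(theano.config.floatX),
        np.random.randn(16, 4, 3, 3).astype(theano.config.floatX))
print(out.shape)  # (1, 16, 30, 30) in the default 'valid' border mode
```

Note the commit message: only the Python implementation gains group support here, so whether a compiled graph accepts num_groups > 1 depends on which concrete convolution the optimizer substitutes for the abstract op.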
@@ -1396,7 +1397,7 @@ class BaseAbstractConv(Op):
     def __init__(self, convdim,
                  imshp=None, kshp=None, border_mode="valid",
-                 subsample=None, filter_flip=True, filter_dilation=None):
+                 subsample=None, filter_flip=True, filter_dilation=None, num_groups=1):
         self.convdim = convdim

         if convdim not in (2, 3):

@@ -1458,6 +1459,11 @@ class BaseAbstractConv(Op):
         if len(filter_dilation) != convdim:
             raise ValueError("filter_dilation must have {} elements".format(convdim))
         self.filter_dilation = tuple(filter_dilation)
+        if num_groups < 1:
+            raise ValueError("num_groups must have value greater than zero")
+        elif num_groups > 1 and convdim == 3:
+            raise ValueError("grouped convolution not supported for 3D convolutions")
+        self.num_groups = num_groups

     def do_constant_folding(self, node):
         # Disable constant folding since there is no implementation.
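These checks only constrain the attribute itself: num_groups must be positive, and for convdim == 3 any value above 1 is rejected. Divisibility of the actual channel counts by num_groups is left to later shape handling. An illustrative sketch of what the constructor accepts and rejects:

```python
from theano.tensor.nnet.abstract_conv import AbstractConv2d

AbstractConv2d(num_groups=2)  # ok: grouped 2D convolution
AbstractConv2d(num_groups=0)  # ValueError: num_groups must have value greater than zero
```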
@@ -1467,6 +1473,9 @@ class BaseAbstractConv(Op):
     def flops(self, inp, outp):
         """ Useful with the hack in profiling to print the MFlops"""
         if self.convdim == 2:
+            if self.num_groups > 1:
+                raise NotImplementedError(
+                    'flops not implemented for grouped convolution')
             # if the output shape is correct, then this gives the correct
             # flops for any direction, sampling, padding, and border mode
             inputs, filters = inp

@@ -1484,7 +1493,7 @@ class BaseAbstractConv(Op):
             raise NotImplementedError(
                 'flops not implemented for convdim={}', self.convdim)

-    def conv(self, img, kern, mode="valid", dilation=1):
+    def conv(self, img, kern, mode="valid", dilation=1, num_groups=1):
         """
         Basic slow Python 2D or 3D convolution for DebugMode
         """

@@ -1519,16 +1528,19 @@ class BaseAbstractConv(Op):
         if self.convdim == 2:
             val = _valfrommode(mode)
             bval = _bvalfromboundary('fill')
+            input_channel_offset = img.shape[1] // self.num_groups
+            output_channel_offset = kern.shape[0] // self.num_groups

             with warnings.catch_warnings():
                 warnings.simplefilter('ignore', np.ComplexWarning)
                 for b in xrange(img.shape[0]):
-                    for n in xrange(kern.shape[0]):
-                        for im0 in xrange(img.shape[1]):
-                            # some cast generates a warning here
-                            out[b, n, ...] += _convolve2d(img[b, im0, ...],
-                                                          dilated_kern[n, im0, ...],
-                                                          1, val, bval, 0)
+                    for g in xrange(self.num_groups):
+                        for n in xrange(output_channel_offset):
+                            for im0 in xrange(input_channel_offset):
+                                # some cast generates a warning here
+                                out[b, g * output_channel_offset + n, ...] += _convolve2d(
+                                    img[b, g * input_channel_offset + im0, ...],
+                                    dilated_kern[g * output_channel_offset + n, im0, ...],
+                                    1, val, bval, 0)
         elif self.convdim == 3:
             for b in xrange(img.shape[0]):
                 for n in xrange(kern.shape[0]):
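The rewritten loop is the heart of the patch: channels are partitioned into num_groups contiguous blocks, and output channel g * output_channel_offset + n accumulates contributions only from group g's input slice. Below is a self-contained NumPy sketch of the same indexing, with scipy.signal.convolve2d standing in for Theano's _convolve2d (the helper name and shapes are illustrative, 'valid' mode, no dilation):

```python
import numpy as np
from scipy.signal import convolve2d

def grouped_conv2d(img, kern, num_groups):
    """img: (batch, in_ch, H, W); kern: (out_ch, in_ch // num_groups, kh, kw)."""
    batch, in_ch, h, w = img.shape
    out_ch, _, kh, kw = kern.shape
    in_off = in_ch // num_groups    # input channels per group
    out_off = out_ch // num_groups  # output channels per group
    out = np.zeros((batch, out_ch, h - kh + 1, w - kw + 1), dtype=img.dtype)
    for b in range(batch):
        for g in range(num_groups):
            for n in range(out_off):
                for im0 in range(in_off):
                    # filter g*out_off + n only sees input channel
                    # g*in_off + im0 of its own group
                    out[b, g * out_off + n] += convolve2d(
                        img[b, g * in_off + im0],
                        kern[g * out_off + n, im0], mode='valid')
    return out

img = np.random.randn(1, 8, 32, 32)
kern = np.random.randn(16, 4, 3, 3)
print(grouped_conv2d(img, kern, num_groups=2).shape)  # (1, 16, 30, 30)
```

With num_groups=1 this reduces to the ordinary dense convolution; setting num_groups equal to both channel counts gives a depthwise convolution.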
@@ -1554,13 +1566,15 @@ class AbstractConv(BaseAbstractConv):
                  border_mode="valid",
                  subsample=None,
                  filter_flip=True,
-                 filter_dilation=None):
+                 filter_dilation=None,
+                 num_groups=1):
         super(AbstractConv, self).__init__(convdim=convdim,
                                            imshp=imshp, kshp=kshp,
                                            border_mode=border_mode,
                                            subsample=subsample,
                                            filter_flip=filter_flip,
-                                           filter_dilation=filter_dilation)
+                                           filter_dilation=filter_dilation,
+                                           num_groups=num_groups)

     def make_node(self, img, kern):
         # Make sure both inputs are Variables with the same Type

@@ -1622,7 +1636,7 @@ class AbstractConv(BaseAbstractConv):
             img = new_img
         if not self.filter_flip:
             kern = kern[(slice(None), slice(None)) + (slice(None, None, -1),) * self.convdim]
-        conv_out = self.conv(img, kern, mode="valid", dilation=self.filter_dilation)
+        conv_out = self.conv(img, kern, mode="valid", dilation=self.filter_dilation, num_groups=self.num_groups)
         conv_out = conv_out[(slice(None), slice(None)) +
                             tuple(slice(None, None, self.subsample[i])
                                   for i in range(self.convdim))]

@@ -1630,6 +1644,9 @@ class AbstractConv(BaseAbstractConv):
         o[0] = node.outputs[0].type.filter(conv_out)

     def R_op(self, inputs, eval_points):
+        if self.num_groups > 1:
+            raise NotImplementedError(
+                'Rop not implemented for grouped convolutions')
         rval = None
         if eval_points[0] is not None:
             rval = self.make_node(eval_points[0], inputs[1]).outputs[0]

@@ -1668,13 +1685,15 @@ class AbstractConv2d(AbstractConv):
                  border_mode="valid",
                  subsample=(1, 1),
                  filter_flip=True,
-                 filter_dilation=(1, 1)):
+                 filter_dilation=(1, 1),
+                 num_groups=1):
         super(AbstractConv2d, self).__init__(convdim=2,
                                              imshp=imshp, kshp=kshp,
                                              border_mode=border_mode,
                                              subsample=subsample,
                                              filter_flip=filter_flip,
-                                             filter_dilation=filter_dilation)
+                                             filter_dilation=filter_dilation,
+                                             num_groups=num_groups)

     def grad(self, inp, grads):
         bottom, weights = inp

@@ -1684,13 +1703,15 @@ class AbstractConv2d(AbstractConv):
                                               self.border_mode,
                                               self.subsample,
                                               self.filter_flip,
-                                              self.filter_dilation)(
+                                              self.filter_dilation,
+                                              num_groups=self.num_groups)(
             weights, top, bottom.shape[-2:], add_assert_shape=False)
         d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp,
                                                self.border_mode,
                                                self.subsample,
                                                self.filter_flip,
-                                               self.filter_dilation)(
+                                               self.filter_dilation,
+                                               num_groups=self.num_groups)(
             bottom, top, weights.shape[-2:], add_assert_shape=False)
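Both gradient ops are now constructed with the op's num_groups, so differentiating through a grouped convolution stays grouped. Continuing the hypothetical conv2d example from above:

```python
# Gradients w.r.t. the input and the filters are themselves grouped
# convolutions, built from AbstractConv2d_gradInputs / _gradWeights
# with num_groups=2 threaded through.
cost = y.sum()
gx, gw = theano.grad(cost, [x, w])  # gx: (batch, 8, H, W), gw: (16, 4, kh, kw)
```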
@@ -1772,13 +1793,15 @@ class AbstractConv_gradWeights(BaseAbstractConv):
                  border_mode="valid",
                  subsample=None,
                  filter_flip=True,
-                 filter_dilation=None):
+                 filter_dilation=None,
+                 num_groups=1):
         super(AbstractConv_gradWeights, self).__init__(convdim=convdim,
                                                        imshp=imshp, kshp=kshp,
                                                        border_mode=border_mode,
                                                        subsample=subsample,
                                                        filter_flip=filter_flip,
-                                                       filter_dilation=filter_dilation)
+                                                       filter_dilation=filter_dilation,
+                                                       num_groups=num_groups)

     # Update shape/height_width
     def make_node(self, img, topgrad, shape, add_assert_shape=True):

@@ -1856,7 +1879,19 @@ class AbstractConv_gradWeights(BaseAbstractConv):
                         (slice(None, None, -1),) * self.convdim)
         topgrad = topgrad.transpose(axes_order)[flip_filters]
         img = img.transpose(axes_order)
-        kern = self.conv(img, topgrad, mode="valid")
+
+        def correct_for_groups(mat):
+            mshp0 = mat.shape[0] // self.num_groups
+            mshp1 = mat.shape[1] * self.num_groups
+            mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
+            mat = mat.transpose((1, 0, 2, 3, 4))
+            mat = mat.reshape((mshp0, mshp1) + mat.shape[-2:])
+            return mat
+
+        if self.num_groups > 1:
+            img = correct_for_groups(img)
+
+        kern = self.conv(img, topgrad, mode="valid", num_groups=self.num_groups)
         if any(self.filter_dilation[i] > 1 for i in range(self.convdim)):
             kern = kern[(slice(None), slice(None)) +
                         tuple(slice(None, None, self.filter_dilation[i])
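correct_for_groups exists because, after the transpose above, the group blocks sit along axis 0 while conv expects them along axis 1: the helper folds a leading axis of num_groups blocks into the second dimension, turning shape (g*m0, m1, h, w) into (m0, g*m1, h, w). A small NumPy round trip showing the effect (the 4D shape is illustrative):

```python
import numpy as np

def correct_for_groups(mat, num_groups):
    # Fold the num_groups leading blocks of axis 0 into axis 1:
    # (g*m0, m1, h, w) -> (m0, g*m1, h, w)
    mshp0 = mat.shape[0] // num_groups
    mshp1 = mat.shape[1] * num_groups
    mat = mat.reshape((num_groups, mshp0) + mat.shape[1:])
    mat = mat.transpose((1, 0, 2, 3, 4))
    return mat.reshape((mshp0, mshp1) + mat.shape[-2:])

mat = np.arange(8 * 3).reshape(8, 3, 1, 1)        # e.g. 8 rows, 2 groups of 4
print(correct_for_groups(mat, num_groups=2).shape)  # (4, 6, 1, 1)
```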
@@ -1901,13 +1936,15 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
                  border_mode="valid",
                  subsample=(1, 1),
                  filter_flip=True,
-                 filter_dilation=(1, 1)):
+                 filter_dilation=(1, 1),
+                 num_groups=1):
         super(AbstractConv2d_gradWeights, self).__init__(convdim=2,
                                                          imshp=imshp, kshp=kshp,
                                                          border_mode=border_mode,
                                                          subsample=subsample,
                                                          filter_flip=filter_flip,
-                                                         filter_dilation=filter_dilation)
+                                                         filter_dilation=filter_dilation,
+                                                         num_groups=num_groups)

     def grad(self, inp, grads):
         bottom, top = inp[:2]

@@ -2011,13 +2048,15 @@ class AbstractConv_gradInputs(BaseAbstractConv):
                  border_mode="valid",
                  subsample=None,
                  filter_flip=True,
-                 filter_dilation=None):
+                 filter_dilation=None,
+                 num_groups=1):
         super(AbstractConv_gradInputs, self).__init__(convdim=convdim,
                                                       imshp=imshp, kshp=kshp,
                                                       border_mode=border_mode,
                                                       subsample=subsample,
                                                       filter_flip=filter_flip,
-                                                      filter_dilation=filter_dilation)
+                                                      filter_dilation=filter_dilation,
+                                                      num_groups=num_groups)

     # Update shape/height_width
     def make_node(self, kern, topgrad, shape, add_assert_shape=True):

@@ -2097,10 +2136,20 @@ class AbstractConv_gradInputs(BaseAbstractConv):
         axes_order = (1, 0) + tuple(range(2, self.convdim + 2))
         flip_filters = ((slice(None), slice(None)) +
                         (slice(None, None, -1),) * self.convdim)
+
+        def correct_for_groups(mat):
+            mshp0 = mat.shape[0] // self.num_groups
+            mshp1 = mat.shape[1] * self.num_groups
+            mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
+            mat = mat.transpose((1, 0, 2, 3, 4))
+            mat = mat.reshape((mshp0, mshp1) + mat.shape[-2:])
+            return mat
+
+        kern = correct_for_groups(kern)
         kern = kern.transpose(axes_order)
         if self.filter_flip:
             topgrad = topgrad[flip_filters]
-        img = self.conv(topgrad, kern, mode="full", dilation=self.filter_dilation)
+        img = self.conv(topgrad, kern, mode="full", dilation=self.filter_dilation, num_groups=self.num_groups)
         if self.filter_flip:
             img = img[flip_filters]
         if any(p > 0 for p in pad):

@@ -2144,13 +2193,15 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
                  border_mode="valid",
                  subsample=(1, 1),
                  filter_flip=True,
-                 filter_dilation=(1, 1)):
+                 filter_dilation=(1, 1),
+                 num_groups=1):
         super(AbstractConv2d_gradInputs, self).__init__(convdim=2,
                                                         imshp=imshp, kshp=kshp,
                                                         border_mode=border_mode,
                                                         subsample=subsample,
                                                         filter_flip=filter_flip,
-                                                        filter_dilation=filter_dilation)
+                                                        filter_dilation=filter_dilation,
+                                                        num_groups=num_groups)

     def grad(self, inp, grads):
         weights, top = inp[:2]
...