提交 c2e14ce1 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5991 from affanv14/group

Implement Grouped Convolutions
...@@ -496,13 +496,16 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -496,13 +496,16 @@ class BaseGpuCorrMM(CGpuKernelBase):
Perform subsampling of the output (default: (1, 1)). Perform subsampling of the output (default: (1, 1)).
filter_dilation filter_dilation
Perform subsampling of the input, also known as dilation (default: (1, 1)). Perform subsampling of the input, also known as dilation (default: (1, 1)).
num_groups :
Divides the image, kernel and output tensors into num_groups
separate groups. Each which carry out convolutions separately (default : 1).
""" """
check_broadcast = False check_broadcast = False
__props__ = ('border_mode', 'subsample', 'filter_dilation') __props__ = ('border_mode', 'subsample', 'filter_dilation', 'num_groups')
_f16_ok = True _f16_ok = True
def __init__(self, border_mode="valid", subsample=(1, 1), def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1)): filter_dilation=(1, 1), num_groups=1):
if isinstance(border_mode, integer_types): if isinstance(border_mode, integer_types):
border_mode = (border_mode, border_mode) border_mode = (border_mode, border_mode)
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
...@@ -521,6 +524,9 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -521,6 +524,9 @@ class BaseGpuCorrMM(CGpuKernelBase):
raise ValueError("filter_dilation must have two elements") raise ValueError("filter_dilation must have two elements")
self.subsample = tuple(subsample) self.subsample = tuple(subsample)
self.filter_dilation = tuple(filter_dilation) self.filter_dilation = tuple(filter_dilation)
if num_groups < 1:
raise ValueError("Number of groups should be greater than 0")
self.num_groups = num_groups
CGpuKernelBase.__init__(self, ['corr_gemm.c']) CGpuKernelBase.__init__(self, ['corr_gemm.c'])
@property @property
...@@ -530,11 +536,17 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -530,11 +536,17 @@ class BaseGpuCorrMM(CGpuKernelBase):
return (0, 0) return (0, 0)
def __str__(self): def __str__(self):
return '%s{%s, %s, %s}' % ( return '%s{%s, %s, %s, %s}' % (
self.__class__.__name__, self.__class__.__name__,
self.border_mode, self.border_mode,
str(self.subsample), str(self.subsample),
str(self.filter_dilation)) str(self.filter_dilation),
str(self.num_groups))
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def flops(self, inp, outp): def flops(self, inp, outp):
""" """
...@@ -562,7 +574,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -562,7 +574,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
def c_code_cache_version(self): def c_code_cache_version(self):
# Raise this whenever modifying the C code (including the file). # Raise this whenever modifying the C code (including the file).
return (8,) return (9,)
def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None): def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None):
""" """
...@@ -609,6 +621,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -609,6 +621,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
""" """
dH, dW = self.subsample dH, dW = self.subsample
dilH, dilW = self.filter_dilation dilH, dilW = self.filter_dilation
numgroups = self.num_groups
if self.border_mode == "half": if self.border_mode == "half":
padH = padW = -1 padH = padW = -1
elif self.border_mode == "full": elif self.border_mode == "full":
...@@ -669,6 +682,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -669,6 +682,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
size_t dilW = %(dilW)s; size_t dilW = %(dilW)s;
int padH = %(padH)s; int padH = %(padH)s;
int padW = %(padW)s; int padW = %(padW)s;
int numgroups = %(numgroups)s;
PyGpuArrayObject * bottom = %(bottom)s; PyGpuArrayObject * bottom = %(bottom)s;
PyGpuArrayObject * weights = %(weights)s; PyGpuArrayObject * weights = %(weights)s;
...@@ -768,7 +782,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -768,7 +782,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
// output is weights: (num_filters, num_channels, height, width) // output is weights: (num_filters, num_channels, height, width)
// height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1 // height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
out_dim[0] = PyGpuArray_DIMS(top)[1]; out_dim[0] = PyGpuArray_DIMS(top)[1];
out_dim[1] = PyGpuArray_DIMS(bottom)[1]; out_dim[1] = PyGpuArray_DIMS(bottom)[1] / numgroups;
out_dim[2] = kH; // already inferred further above out_dim[2] = kH; // already inferred further above
out_dim[3] = kW; // how convenient out_dim[3] = kW; // how convenient
out_typecode = top->ga.typecode; out_typecode = top->ga.typecode;
...@@ -792,7 +806,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -792,7 +806,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
// output is bottom: (batchsize, num_channels, height, width) // output is bottom: (batchsize, num_channels, height, width)
// height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad // height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
out_dim[0] = PyGpuArray_DIMS(top)[0]; out_dim[0] = PyGpuArray_DIMS(top)[0];
out_dim[1] = PyGpuArray_DIMS(weights)[1]; out_dim[1] = PyGpuArray_DIMS(weights)[1] * numgroups;
out_dim[2] = (%(height)s != -1) ? %(height)s : (PyGpuArray_DIMS(top)[2] - 1) * dH + (PyGpuArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH; out_dim[2] = (%(height)s != -1) ? %(height)s : (PyGpuArray_DIMS(top)[2] - 1) * dH + (PyGpuArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH;
out_dim[3] = (%(width)s != -1) ? %(width)s : (PyGpuArray_DIMS(top)[3] - 1) * dW + (PyGpuArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW; out_dim[3] = (%(width)s != -1) ? %(width)s : (PyGpuArray_DIMS(top)[3] - 1) * dW + (PyGpuArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW;
out_typecode = top->ga.typecode; out_typecode = top->ga.typecode;
...@@ -836,7 +850,7 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -836,7 +850,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
} }
// Call GPU code // Call GPU code
out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW); out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW, numgroups);
if (out2==NULL){ if (out2==NULL){
%(fail)s %(fail)s
} }
...@@ -873,6 +887,11 @@ class GpuCorrMM(BaseGpuCorrMM): ...@@ -873,6 +887,11 @@ class GpuCorrMM(BaseGpuCorrMM):
The filter dilation operation applied to each input image. The filter dilation operation applied to each input image.
Should be a tuple with 2 elements. Should be a tuple with 2 elements.
Set to `(1, 1)` to disable filter dilation. Set to `(1, 1)` to disable filter dilation.
num_groups
The number of distinct groups the image and kernel must be
divided into.
should be an int
set to 1 to disable grouped convolution
Notes Notes
----- -----
...@@ -892,9 +911,9 @@ class GpuCorrMM(BaseGpuCorrMM): ...@@ -892,9 +911,9 @@ class GpuCorrMM(BaseGpuCorrMM):
""" """
def __init__(self, border_mode="valid", def __init__(self, border_mode="valid",
subsample=(1, 1), subsample=(1, 1),
filter_dilation=(1, 1)): filter_dilation=(1, 1), num_groups=1):
super(GpuCorrMM, self).__init__(border_mode, subsample, super(GpuCorrMM, self).__init__(border_mode, subsample,
filter_dilation) filter_dilation, num_groups)
def make_node(self, img, kern): def make_node(self, img, kern):
ctx_name = infer_context_name(img, kern) ctx_name = infer_context_name(img, kern)
...@@ -923,11 +942,13 @@ class GpuCorrMM(BaseGpuCorrMM): ...@@ -923,11 +942,13 @@ class GpuCorrMM(BaseGpuCorrMM):
top = gpu_contiguous(top) top = gpu_contiguous(top)
d_bottom = GpuCorrMM_gradInputs(self.border_mode, d_bottom = GpuCorrMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)( self.filter_dilation,
self.num_groups)(
weights, top, bottom.shape[-2:]) weights, top, bottom.shape[-2:])
d_weights = GpuCorrMM_gradWeights(self.border_mode, d_weights = GpuCorrMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)( self.filter_dilation,
self.num_groups)(
bottom, top, weights.shape[-2:]) bottom, top, weights.shape[-2:])
return d_bottom, d_weights return d_bottom, d_weights
...@@ -945,10 +966,11 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM): ...@@ -945,10 +966,11 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM):
def __init__(self, border_mode="valid", def __init__(self, border_mode="valid",
subsample=(1, 1), subsample=(1, 1),
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
super(GpuCorrMM_gradWeights, self).__init__(border_mode, super(GpuCorrMM_gradWeights, self).__init__(border_mode,
subsample, subsample,
filter_dilation) filter_dilation, num_groups)
def make_node(self, img, topgrad, shape=None): def make_node(self, img, topgrad, shape=None):
ctx_name = infer_context_name(img, topgrad) ctx_name = infer_context_name(img, topgrad)
...@@ -987,11 +1009,12 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM): ...@@ -987,11 +1009,12 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM):
weights = gpu_contiguous(weights) weights = gpu_contiguous(weights)
d_bottom = GpuCorrMM_gradInputs(self.border_mode, d_bottom = GpuCorrMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, self.filter_dilation,
self.num_groups)(weights,
top, top,
bottom.shape[-2:]) bottom.shape[-2:])
d_top = GpuCorrMM( d_top = GpuCorrMM(
self.border_mode, self.subsample, self.filter_dilation)(bottom, weights) self.border_mode, self.subsample, self.filter_dilation, self.num_groups)(bottom, weights)
d_height_width = ( d_height_width = (
theano.gradient.DisconnectedType()(), theano.gradient.DisconnectedType()(),
) * 2 if len(inp) == 4 else () ) * 2 if len(inp) == 4 else ()
...@@ -1017,9 +1040,10 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM): ...@@ -1017,9 +1040,10 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
def __init__(self, border_mode="valid", def __init__(self, border_mode="valid",
subsample=(1, 1), subsample=(1, 1),
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
super(GpuCorrMM_gradInputs, self).__init__(border_mode, subsample, super(GpuCorrMM_gradInputs, self).__init__(border_mode, subsample,
filter_dilation) filter_dilation, num_groups)
def make_node(self, kern, topgrad, shape=None): def make_node(self, kern, topgrad, shape=None):
ctx_name = infer_context_name(kern, topgrad) ctx_name = infer_context_name(kern, topgrad)
...@@ -1038,6 +1062,10 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM): ...@@ -1038,6 +1062,10 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
assert shape[0].ndim == 0 assert shape[0].ndim == 0
assert shape[1].ndim == 0 assert shape[1].ndim == 0
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0], False,
False, False]
else:
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1], broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False] False, False]
return Apply(self, [kern, topgrad] + height_width, [GpuArrayType(dtype=topgrad.dtype, return Apply(self, [kern, topgrad] + height_width, [GpuArrayType(dtype=topgrad.dtype,
...@@ -1057,12 +1085,14 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM): ...@@ -1057,12 +1085,14 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
bottom = gpu_contiguous(bottom) bottom = gpu_contiguous(bottom)
d_weights = GpuCorrMM_gradWeights(self.border_mode, d_weights = GpuCorrMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, self.filter_dilation,
self.num_groups)(bottom,
top, top,
weights.shape[-2:]) weights.shape[-2:])
d_top = GpuCorrMM(self.border_mode, d_top = GpuCorrMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
d_height_width = ( d_height_width = (
theano.gradient.DisconnectedType()(), theano.gradient.DisconnectedType()(),
) * 2 if len(inp) == 4 else () ) * 2 if len(inp) == 4 else ()
......
...@@ -348,7 +348,8 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -348,7 +348,8 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
const size_t dilH = 1, const size_t dilH = 1,
const size_t dilW = 1, const size_t dilW = 1,
const size_t padH = 0, const size_t padH = 0,
const size_t padW = 0) const size_t padW = 0,
const size_t numgroups = 1)
{ {
if (PyGpuArray_NDIM(bottom) != 4) if (PyGpuArray_NDIM(bottom) != 4)
{ {
...@@ -411,7 +412,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -411,7 +412,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
const size_t nFilters = PyGpuArray_DIMS(weight)[0]; const size_t nFilters = PyGpuArray_DIMS(weight)[0];
const size_t kH = PyGpuArray_DIMS(weight)[2]; const size_t kH = PyGpuArray_DIMS(weight)[2];
const size_t kW = PyGpuArray_DIMS(weight)[3]; const size_t kW = PyGpuArray_DIMS(weight)[3];
if (nChannels != PyGpuArray_DIMS(weight)[1]) { if (nChannels != (PyGpuArray_DIMS(weight)[1] * numgroups)) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"GpuCorrMM images and kernel must have the same stack size\n"); "GpuCorrMM images and kernel must have the same stack size\n");
return NULL; return NULL;
...@@ -469,11 +470,15 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -469,11 +470,15 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
} }
// Define some useful variables // Define some useful variables
const size_t bottom_stride = PyGpuArray_STRIDES(bottom)[0] / gpuarray_get_elsize(bottom->ga.typecode); const size_t batch_bottom_stride = PyGpuArray_STRIDES(bottom)[0] / gpuarray_get_elsize(bottom->ga.typecode);
const size_t top_stride = PyGpuArray_STRIDES(top)[0] / gpuarray_get_elsize(top->ga.typecode); const size_t batch_top_stride = PyGpuArray_STRIDES(top)[0] / gpuarray_get_elsize(top->ga.typecode);
const size_t K_ = col_dim[0]; const size_t group_bottom_stride = (PyGpuArray_STRIDES(bottom)[1] * nChannels / numgroups) / gpuarray_get_elsize(bottom->ga.typecode);
const size_t group_top_stride = (PyGpuArray_STRIDES(top)[1] * nFilters / numgroups) / gpuarray_get_elsize(top->ga.typecode);
const size_t group_weight_stride = (PyGpuArray_STRIDES(weight)[0] * nFilters / numgroups) / gpuarray_get_elsize(weight->ga.typecode);
const size_t K_ = col_dim[0] / numgroups;
const size_t N_ = col_dim[1]; const size_t N_ = col_dim[1];
const size_t M_ = nFilters; const size_t group_col_stride = (K_ * N_);
const size_t M_ = nFilters / numgroups;
PyGpuArrayObject *output; PyGpuArrayObject *output;
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
...@@ -493,7 +498,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -493,7 +498,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im2col // First, im2col
err = im2col(&bottom->ga, n * bottom_stride, err = im2col(&bottom->ga, n * batch_bottom_stride,
nChannels, bottomHeight, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW, bottomWidth, kH, kW, dilH, dilW,
padH, padW, dH, dW, &col->ga); padH, padW, dH, dW, &col->ga);
...@@ -502,12 +507,14 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -502,12 +507,14 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
return NULL; return NULL;
} }
// Second, gemm // Second, gemm
for (size_t g = 0; g < numgroups; g++){
err = rgemm(cb_fortran, cb_no_trans, cb_no_trans, err = rgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1, N_, M_, K_, 1,
&col->ga, 0, N_, &col->ga, g * group_col_stride, N_,
&weight->ga, 0, K_, &weight->ga, g * group_weight_stride, K_,
0, 0,
&top->ga, n * top_stride, N_); &top->ga, n * batch_top_stride + g * group_top_stride, N_);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM forward encountered an error running gemm: %d", err); "GpuCorrMM forward encountered an error running gemm: %d", err);
...@@ -533,7 +540,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -533,7 +540,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im2col // First, im2col
err = im2col(&bottom->ga, n * bottom_stride, err = im2col(&bottom->ga, n * batch_bottom_stride,
nChannels, bottomHeight, nChannels, bottomHeight,
bottomWidth, kH, kW, dilH, dilW, bottomWidth, kH, kW, dilH, dilW,
padH, padW, dH, dW, &col->ga); padH, padW, dH, dW, &col->ga);
...@@ -545,12 +552,14 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -545,12 +552,14 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
for(size_t g = 0; g < numgroups; g++){
err = rgemm(cb_fortran, cb_trans, cb_no_trans, err = rgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1, K_, M_, N_, 1,
&col->ga, 0, N_, &col->ga, g * group_col_stride, N_,
&top->ga, n * top_stride, N_, &top->ga, n * batch_top_stride + g * group_top_stride, N_,
(n == 0) ? 0 : 1, (n == 0) ? 0 : 1,
&weight->ga, 0, K_); &weight->ga, g * group_weight_stride, K_);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad weights encountered an error running gemm: %d", err); "GpuCorrMM grad weights encountered an error running gemm: %d", err);
...@@ -576,12 +585,14 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -576,12 +585,14 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// gemm into columns // gemm into columns
for(size_t g = 0; g < numgroups; g++){
err = rgemm(cb_fortran, cb_no_trans, cb_trans, err = rgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1, N_, K_, M_, 1,
&top->ga, n * top_stride, N_, &top->ga, n * batch_top_stride + g * group_top_stride, N_,
&weight->ga, 0, K_, &weight->ga, g * group_weight_stride, K_,
0, 0,
&col->ga, 0, N_); &col->ga, g * group_col_stride, N_);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad inputs encountered an error running gemm: %d", err); "GpuCorrMM grad inputs encountered an error running gemm: %d", err);
...@@ -591,7 +602,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom, ...@@ -591,7 +602,7 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
// col2im back to the data // col2im back to the data
err = col2im(&col->ga, nChannels, bottomHeight, bottomWidth, err = col2im(&col->ga, nChannels, bottomHeight, bottomWidth,
kH, kW, dilH, dilW, padH, padW, kH, kW, dilH, dilW, padH, padW,
dH, dW, &bottom->ga, n * bottom_stride); dH, dW, &bottom->ga, n * batch_bottom_stride);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
......
...@@ -503,18 +503,22 @@ class GpuDnnConv(DnnBase): ...@@ -503,18 +503,22 @@ class GpuDnnConv(DnnBase):
algo : {'small', 'none', 'large', 'fft', 'fft_tiling', 'winograd', 'guess_once', algo : {'small', 'none', 'large', 'fft', 'fft_tiling', 'winograd', 'guess_once',
'guess_on_shape_change', 'time_once', 'time_on_shape_change'} 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_fwd`. Default is the value of :attr:`config.dnn.conv.algo_fwd`.
num_groups :
Divides the image, kernel and output tensors into num_groups
separate groups. Each which carry out convolutions separately
""" """
_f16_ok = True _f16_ok = True
__props__ = ('algo', 'inplace') __props__ = ('algo', 'inplace', 'num_groups')
check_input = False check_input = False
params_type = ParamsType(conv_algo=cudnn.cudnnConvolutionFwdAlgo_t, params_type = ParamsType(conv_algo=cudnn.cudnnConvolutionFwdAlgo_t,
choose_algo=bool_t, choose_once=bool_t, choose_time=bool_t, choose_algo=bool_t, choose_once=bool_t, choose_time=bool_t,
inplace=bool_t, inplace=bool_t,
handle=handle_type) handle=handle_type,
num_groups=int_t)
def __init__(self, algo=None, inplace=False): def __init__(self, algo=None, inplace=False, num_groups=1):
DnnBase.__init__(self, ["dnn_conv_base.c", "dnn_fwd.c"], DnnBase.__init__(self, ["dnn_conv_base.c", "dnn_fwd.c"],
"APPLY_SPECIFIC(conv_fwd)") "APPLY_SPECIFIC(conv_fwd)")
...@@ -534,6 +538,7 @@ class GpuDnnConv(DnnBase): ...@@ -534,6 +538,7 @@ class GpuDnnConv(DnnBase):
self.choose_algo = self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME self.choose_algo = self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME
self.choose_once = self.algo in DNN_CONV_ALGO_CHOOSE_ONCE self.choose_once = self.algo in DNN_CONV_ALGO_CHOOSE_ONCE
self.choose_time = self.algo in DNN_CONV_ALGO_CHOOSE_TIME self.choose_time = self.algo in DNN_CONV_ALGO_CHOOSE_TIME
self.num_groups = num_groups
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
...@@ -544,6 +549,8 @@ class GpuDnnConv(DnnBase): ...@@ -544,6 +549,8 @@ class GpuDnnConv(DnnBase):
self.algo = config.dnn.conv.algo_fwd self.algo = config.dnn.conv.algo_fwd
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def make_node(self, img, kern, output, desc, alpha=None, beta=None): def make_node(self, img, kern, output, desc, alpha=None, beta=None):
ctx_name = infer_context_name(img, kern, output) ctx_name = infer_context_name(img, kern, output)
...@@ -567,6 +574,8 @@ class GpuDnnConv(DnnBase): ...@@ -567,6 +574,8 @@ class GpuDnnConv(DnnBase):
SUPPORTED_DNN_CONV_ALGO_RUNTIME): SUPPORTED_DNN_CONV_ALGO_RUNTIME):
raise ValueError("convolution algo %s can't be used for " raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,)) "3d convolutions", (self.algo,))
if img.type.ndim == 5 and self.num_groups != 1:
raise ValueError("Grouped convolutions not implemented for 3D convolutions")
if (not isinstance(desc.type, CDataType) or if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnConvolutionDescriptor_t'): desc.type.ctype != 'cudnnConvolutionDescriptor_t'):
...@@ -584,8 +593,8 @@ class GpuDnnConv(DnnBase): ...@@ -584,8 +593,8 @@ class GpuDnnConv(DnnBase):
top = gpu_contiguous(top) top = gpu_contiguous(top)
d_img = GpuDnnConvGradI()(kerns, top, empty_like(img), desc) d_img = GpuDnnConvGradI(num_groups=self.num_groups)(kerns, top, empty_like(img), desc)
d_kerns = GpuDnnConvGradW()(img, top, empty_like(kerns), desc) d_kerns = GpuDnnConvGradW(num_groups=self.num_groups)(img, top, empty_like(kerns), desc)
d_alpha = grad_not_implemented(self, 4, alpha) d_alpha = grad_not_implemented(self, 4, alpha)
d_beta = grad_not_implemented(self, 5, beta) d_beta = grad_not_implemented(self, 5, beta)
...@@ -637,18 +646,22 @@ class GpuDnnConvGradW(DnnBase): ...@@ -637,18 +646,22 @@ class GpuDnnConvGradW(DnnBase):
algo : {'none', 'deterministic', 'fft', 'small', 'guess_once', algo : {'none', 'deterministic', 'fft', 'small', 'guess_once',
'guess_on_shape_change', 'time_once', 'time_on_shape_change'} 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`. Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`.
num_groups :
Divides the image, kernel and output tensors into num_groups
separate groups. Each which carry out convolutions separately
""" """
_f16_ok = True _f16_ok = True
__props__ = ('algo', 'inplace') __props__ = ('algo', 'inplace', 'num_groups')
check_input = False check_input = False
params_type = ParamsType(conv_algo=cudnn.cudnnConvolutionBwdFilterAlgo_t, params_type = ParamsType(conv_algo=cudnn.cudnnConvolutionBwdFilterAlgo_t,
choose_algo=bool_t, choose_once=bool_t, choose_time=bool_t, choose_algo=bool_t, choose_once=bool_t, choose_time=bool_t,
inplace=bool_t, inplace=bool_t,
handle=handle_type) handle=handle_type,
num_groups=int_t)
def __init__(self, inplace=False, algo=None): def __init__(self, inplace=False, algo=None, num_groups=1):
DnnBase.__init__(self, ["dnn_conv_base.c", "dnn_gw.c"], DnnBase.__init__(self, ["dnn_conv_base.c", "dnn_gw.c"],
"APPLY_SPECIFIC(conv_gw)") "APPLY_SPECIFIC(conv_gw)")
self.inplace = bool(inplace) self.inplace = bool(inplace)
...@@ -666,6 +679,7 @@ class GpuDnnConvGradW(DnnBase): ...@@ -666,6 +679,7 @@ class GpuDnnConvGradW(DnnBase):
self.choose_algo = self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME self.choose_algo = self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME
self.choose_once = self.algo in DNN_CONV_ALGO_CHOOSE_ONCE self.choose_once = self.algo in DNN_CONV_ALGO_CHOOSE_ONCE
self.choose_time = self.algo in DNN_CONV_ALGO_CHOOSE_TIME self.choose_time = self.algo in DNN_CONV_ALGO_CHOOSE_TIME
self.num_groups = num_groups
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
...@@ -673,6 +687,8 @@ class GpuDnnConvGradW(DnnBase): ...@@ -673,6 +687,8 @@ class GpuDnnConvGradW(DnnBase):
self.inplace = False self.inplace = False
if not hasattr(self, 'algo'): if not hasattr(self, 'algo'):
self.algo = config.dnn.conv.algo_bwd_filter self.algo = config.dnn.conv.algo_bwd_filter
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def grad(self, inp, grads): def grad(self, inp, grads):
img, top, output, desc, alpha, beta = inp img, top, output, desc, alpha, beta = inp
...@@ -680,8 +696,8 @@ class GpuDnnConvGradW(DnnBase): ...@@ -680,8 +696,8 @@ class GpuDnnConvGradW(DnnBase):
kerns = gpu_contiguous(kerns) kerns = gpu_contiguous(kerns)
d_img = GpuDnnConvGradI()(kerns, top, empty_like(img), desc) d_img = GpuDnnConvGradI(num_groups=self.num_groups)(kerns, top, empty_like(img), desc)
d_top = GpuDnnConv()(img, kerns, empty_like(top), desc) d_top = GpuDnnConv(num_groups=self.num_groups)(img, kerns, empty_like(top), desc)
d_alpha = grad_not_implemented(self, 4, alpha) d_alpha = grad_not_implemented(self, 4, alpha)
d_beta = grad_not_implemented(self, 5, beta) d_beta = grad_not_implemented(self, 5, beta)
...@@ -766,18 +782,22 @@ class GpuDnnConvGradI(DnnBase): ...@@ -766,18 +782,22 @@ class GpuDnnConvGradI(DnnBase):
algo : {'none', 'deterministic', 'fft', 'fft_tiling', 'winograd', 'guess_once', algo : {'none', 'deterministic', 'fft', 'fft_tiling', 'winograd', 'guess_once',
'guess_on_shape_change', 'time_once', 'time_on_shape_change'} 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_bwd_data`. Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
num_groups :
Divides the image, kernel and output tensors into num_groups
separate groups. Each which carry out convolutions separately
""" """
_f16_ok = True _f16_ok = True
__props__ = ('algo', 'inplace',) __props__ = ('algo', 'inplace', 'num_groups')
check_input = False check_input = False
params_type = ParamsType(conv_algo=cudnn.cudnnConvolutionBwdDataAlgo_t, params_type = ParamsType(conv_algo=cudnn.cudnnConvolutionBwdDataAlgo_t,
choose_algo=bool_t, choose_once=bool_t, choose_time=bool_t, choose_algo=bool_t, choose_once=bool_t, choose_time=bool_t,
inplace=bool_t, inplace=bool_t,
handle=handle_type) handle=handle_type,
num_groups=int_t)
def __init__(self, inplace=False, algo=None): def __init__(self, inplace=False, algo=None, num_groups=1):
DnnBase.__init__(self, ["dnn_conv_base.c", "dnn_gi.c"], DnnBase.__init__(self, ["dnn_conv_base.c", "dnn_gi.c"],
"APPLY_SPECIFIC(conv_gi)") "APPLY_SPECIFIC(conv_gi)")
self.inplace = bool(inplace) self.inplace = bool(inplace)
...@@ -795,6 +815,7 @@ class GpuDnnConvGradI(DnnBase): ...@@ -795,6 +815,7 @@ class GpuDnnConvGradI(DnnBase):
self.choose_algo = self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME self.choose_algo = self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME
self.choose_once = self.algo in DNN_CONV_ALGO_CHOOSE_ONCE self.choose_once = self.algo in DNN_CONV_ALGO_CHOOSE_ONCE
self.choose_time = self.algo in DNN_CONV_ALGO_CHOOSE_TIME self.choose_time = self.algo in DNN_CONV_ALGO_CHOOSE_TIME
self.num_groups = num_groups
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
...@@ -802,6 +823,8 @@ class GpuDnnConvGradI(DnnBase): ...@@ -802,6 +823,8 @@ class GpuDnnConvGradI(DnnBase):
self.algo = config.dnn.conv.algo_bwd_data self.algo = config.dnn.conv.algo_bwd_data
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def grad(self, inp, grads): def grad(self, inp, grads):
kerns, top, output, desc, alpha, beta = inp kerns, top, output, desc, alpha, beta = inp
...@@ -809,8 +832,8 @@ class GpuDnnConvGradI(DnnBase): ...@@ -809,8 +832,8 @@ class GpuDnnConvGradI(DnnBase):
img = gpu_contiguous(img) img = gpu_contiguous(img)
d_kerns = GpuDnnConvGradW()(img, top, empty_like(kerns), desc) d_kerns = GpuDnnConvGradW(num_groups=self.num_groups)(img, top, empty_like(kerns), desc)
d_top = GpuDnnConv()(img, kerns, empty_like(top), desc) d_top = GpuDnnConv(num_groups=self.num_groups)(img, kerns, empty_like(top), desc)
d_alpha = grad_not_implemented(self, 4, alpha) d_alpha = grad_not_implemented(self, 4, alpha)
d_beta = grad_not_implemented(self, 5, beta) d_beta = grad_not_implemented(self, 5, beta)
...@@ -859,7 +882,7 @@ class GpuDnnConvGradI(DnnBase): ...@@ -859,7 +882,7 @@ class GpuDnnConvGradI(DnnBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1), def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
conv_mode='conv', direction_hint=None, workmem=None, conv_mode='conv', direction_hint=None, workmem=None,
algo=None, precision=None): algo=None, precision=None, num_groups=1):
""" """
GPU convolution using cuDNN from NVIDIA. GPU convolution using cuDNN from NVIDIA.
...@@ -902,6 +925,9 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1), ...@@ -902,6 +925,9 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
should be done. Possible values are 'as_input', 'float16', 'float32' should be done. Possible values are 'as_input', 'float16', 'float32'
and 'float64'. Default is the value of and 'float64'. Default is the value of
:attr:`config.dnn.conv.precision`. :attr:`config.dnn.conv.precision`.
num_groups :
Divides the image, kernel and output tensors into num_groups
separate groups. Each which carry out convolutions separately
.. warning:: The cuDNN library only works with GPUs that have a compute .. warning:: The cuDNN library only works with GPUs that have a compute
...@@ -977,7 +1003,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1), ...@@ -977,7 +1003,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
filter_dilation=dilation) filter_dilation=dilation)
out_shp = assert_conv_shape(out_shp) out_shp = assert_conv_shape(out_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp) out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
return GpuDnnConv(algo=algo)(img, kerns, out, desc) return GpuDnnConv(algo=algo, num_groups=num_groups)(img, kerns, out, desc)
def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1, 1, 1), def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1, 1, 1),
...@@ -1101,7 +1127,8 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1 ...@@ -1101,7 +1127,8 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid', def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid',
subsample=(1, 1), dilation=(1, 1), conv_mode='conv', precision=None): subsample=(1, 1), dilation=(1, 1), conv_mode='conv',
precision=None, algo=None, num_groups=1):
""" """
TODO: document this TODO: document this
""" """
...@@ -1116,7 +1143,7 @@ def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid', ...@@ -1116,7 +1143,7 @@ def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid',
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation, desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision)(kerns_shp) conv_mode=conv_mode, precision=precision)(kerns_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*kerns_shp) out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*kerns_shp)
return GpuDnnConvGradW()(img, topgrad, out, desc) return GpuDnnConvGradW(algo=algo, num_groups=num_groups)(img, topgrad, out, desc)
def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid', def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid',
...@@ -1129,7 +1156,8 @@ def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid', ...@@ -1129,7 +1156,8 @@ def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid',
def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid', def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
subsample=(1, 1), dilation=(1, 1), conv_mode='conv', precision=None): subsample=(1, 1), dilation=(1, 1), conv_mode='conv',
precision=None, algo=None, num_groups=1):
""" """
TODO: document this TODO: document this
""" """
...@@ -1144,7 +1172,7 @@ def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid', ...@@ -1144,7 +1172,7 @@ def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation, desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision)(kerns.shape) conv_mode=conv_mode, precision=precision)(kerns.shape)
out = GpuAllocEmpty(dtype=kerns.dtype, context_name=ctx_name)(*img_shp) out = GpuAllocEmpty(dtype=kerns.dtype, context_name=ctx_name)(*img_shp)
return GpuDnnConvGradI()(kerns, topgrad, out, desc) return GpuDnnConvGradI(algo=algo, num_groups=num_groups)(kerns, topgrad, out, desc)
def dnn_gradinput3d(kerns, topgrad, img_shp, border_mode='valid', def dnn_gradinput3d(kerns, topgrad, img_shp, border_mode='valid',
...@@ -2736,7 +2764,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs): ...@@ -2736,7 +2764,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
subsample=op.subsample, subsample=op.subsample,
dilation=op.filter_dilation, dilation=op.filter_dilation,
direction_hint='forward!', direction_hint='forward!',
conv_mode=conv_mode) conv_mode=conv_mode,
num_groups=op.num_groups)
elif isinstance(op, AbstractConv2d_gradWeights): elif isinstance(op, AbstractConv2d_gradWeights):
shape = (inp2.shape[1], inp1.shape[1], shape = (inp2.shape[1], inp1.shape[1],
inputs[2][0], inputs[2][1]) inputs[2][0], inputs[2][1])
...@@ -2744,7 +2773,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs): ...@@ -2744,7 +2773,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
border_mode=op.border_mode, border_mode=op.border_mode,
subsample=op.subsample, subsample=op.subsample,
dilation=op.filter_dilation, dilation=op.filter_dilation,
conv_mode=conv_mode) conv_mode=conv_mode,
num_groups=op.num_groups)
elif isinstance(op, AbstractConv2d_gradInputs): elif isinstance(op, AbstractConv2d_gradInputs):
shape = (inp2.shape[0], inp1.shape[1], shape = (inp2.shape[0], inp1.shape[1],
inputs[2][0], inputs[2][1]) inputs[2][0], inputs[2][1])
...@@ -2752,7 +2782,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs): ...@@ -2752,7 +2782,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
border_mode=op.border_mode, border_mode=op.border_mode,
subsample=op.subsample, subsample=op.subsample,
dilation=op.filter_dilation, dilation=op.filter_dilation,
conv_mode=conv_mode) conv_mode=conv_mode,
num_groups=op.num_groups)
return [rval] return [rval]
...@@ -2837,17 +2868,17 @@ def local_abstractconv_gi_cudnn(node): ...@@ -2837,17 +2868,17 @@ def local_abstractconv_gi_cudnn(node):
@inplace_allocempty(GpuDnnConv, 2) @inplace_allocempty(GpuDnnConv, 2)
def local_dnn_conv_inplace(node, inputs): def local_dnn_conv_inplace(node, inputs):
return [GpuDnnConv(algo=node.op.algo, inplace=True)(*inputs)] return [GpuDnnConv(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
@inplace_allocempty(GpuDnnConvGradW, 2) @inplace_allocempty(GpuDnnConvGradW, 2)
def local_dnn_convgw_inplace(node, inputs): def local_dnn_convgw_inplace(node, inputs):
return [GpuDnnConvGradW(algo=node.op.algo, inplace=True)(*inputs)] return [GpuDnnConvGradW(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
@inplace_allocempty(GpuDnnConvGradI, 2) @inplace_allocempty(GpuDnnConvGradI, 2)
def local_dnn_convgi_inplace(node, inputs): def local_dnn_convgi_inplace(node, inputs):
return [GpuDnnConvGradI(algo=node.op.algo, inplace=True)(*inputs)] return [GpuDnnConvGradI(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
optdb.register('local_dnna_conv_inplace', optdb.register('local_dnna_conv_inplace',
tensor.opt.in2out(local_dnn_conv_inplace, tensor.opt.in2out(local_dnn_conv_inplace,
...@@ -2860,19 +2891,19 @@ optdb.register('local_dnna_conv_inplace', ...@@ -2860,19 +2891,19 @@ optdb.register('local_dnna_conv_inplace',
@register_opt('cudnn') @register_opt('cudnn')
@alpha_merge(GpuDnnConv, alpha_in=4, beta_in=5) @alpha_merge(GpuDnnConv, alpha_in=4, beta_in=5)
def local_dnn_conv_alpha_merge(node, *inputs): def local_dnn_conv_alpha_merge(node, *inputs):
return [GpuDnnConv(algo=node.op.algo)(*inputs)] return [GpuDnnConv(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@alpha_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5) @alpha_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5)
def local_dnn_convw_alpha_merge(node, *inputs): def local_dnn_convw_alpha_merge(node, *inputs):
return [GpuDnnConvGradW(algo=node.op.algo)(*inputs)] return [GpuDnnConvGradW(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@alpha_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5) @alpha_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5)
def local_dnn_convi_alpha_merge(node, *inputs): def local_dnn_convi_alpha_merge(node, *inputs):
return [GpuDnnConvGradI(algo=node.op.algo)(*inputs)] return [GpuDnnConvGradI(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
......
#section support_code #section support_code
static int static int
c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) { c_set_tensor_for_conv(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc, size_t groups) {
cudnnDataType_t dt; cudnnDataType_t dt;
size_t ds; size_t ds;
switch (var->ga.typecode) { switch (var->ga.typecode) {
...@@ -42,7 +42,8 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) { ...@@ -42,7 +42,8 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
strs[i] = 1; strs[i] = 1;
dims[i] = 1; dims[i] = 1;
} }
//only for grouped convolution i.e when groups > 1
dims[1] = dims[1] / groups;
cudnnStatus_t err = cudnnSetTensorNdDescriptor(desc, dt, nd < 3 ? 3 : nd, cudnnStatus_t err = cudnnSetTensorNdDescriptor(desc, dt, nd < 3 ? 3 : nd,
dims, strs); dims, strs);
if (err != CUDNN_STATUS_SUCCESS) { if (err != CUDNN_STATUS_SUCCESS) {
...@@ -54,6 +55,11 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) { ...@@ -54,6 +55,11 @@ c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
return 0; return 0;
} }
static int
c_set_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
return c_set_tensor_for_conv(var, desc, 1);
}
static int c_make_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t *desc) { static int c_make_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t *desc) {
cudnnStatus_t err; cudnnStatus_t err;
err = cudnnCreateTensorDescriptor(desc); err = cudnnCreateTensorDescriptor(desc);
...@@ -71,7 +77,7 @@ static int c_make_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t *desc) ...@@ -71,7 +77,7 @@ static int c_make_tensorNd(PyGpuArrayObject *var, cudnnTensorDescriptor_t *desc)
} }
static int static int
c_set_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t desc) { c_set_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t desc, size_t groups) {
cudnnDataType_t dt; cudnnDataType_t dt;
cudnnStatus_t err; cudnnStatus_t err;
...@@ -111,6 +117,7 @@ c_set_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t desc) { ...@@ -111,6 +117,7 @@ c_set_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t desc) {
/* Filters can't be less than 3d so we pad */ /* Filters can't be less than 3d so we pad */
for (unsigned int i = nd; i < 3; i++) for (unsigned int i = nd; i < 3; i++)
dims[i] = 1; dims[i] = 1;
dims[0] = dims[0] / groups;
if (nd < 3) if (nd < 3)
nd = 3; nd = 3;
...@@ -135,7 +142,7 @@ static int c_make_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t *desc) { ...@@ -135,7 +142,7 @@ static int c_make_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t *desc) {
cudnnGetErrorString(err)); cudnnGetErrorString(err));
return -1; return -1;
} }
if (c_set_filter(var, *desc) != 0) { if (c_set_filter(var, *desc, 1) != 0) {
cudnnDestroyFilterDescriptor(*desc); cudnnDestroyFilterDescriptor(*desc);
return -1; return -1;
} }
......
...@@ -29,7 +29,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns, ...@@ -29,7 +29,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
float af = alpha, bf = beta; float af = alpha, bf = beta;
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
if (PyGpuArray_DIMS(input)[1] != PyGpuArray_DIMS(kerns)[1]) { if (PyGpuArray_DIMS(input)[1] != PyGpuArray_DIMS(kerns)[1] * params->num_groups) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"images and kernel must have the same stack size"); "images and kernel must have the same stack size");
return 1; return 1;
...@@ -72,12 +72,15 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns, ...@@ -72,12 +72,15 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
return 0; return 0;
} }
if (c_set_tensorNd(input, APPLY_SPECIFIC(input)) == -1) if (c_set_tensor_for_conv(input, APPLY_SPECIFIC(input), params->num_groups) == -1)
return 1; return 1;
if (c_set_filter(kerns, APPLY_SPECIFIC(kerns)) == -1) if (c_set_filter(kerns, APPLY_SPECIFIC(kerns), params->num_groups) == -1)
return 1; return 1;
if (c_set_tensorNd(*output, APPLY_SPECIFIC(output)) == -1) if (c_set_tensor_for_conv(*output, APPLY_SPECIFIC(output), params->num_groups) == -1)
return 1; return 1;
size_t input_offset = PyGpuArray_STRIDE(input, 0) / params->num_groups;
size_t kern_offset = PyGpuArray_STRIDE(kerns, 0) * PyGpuArray_DIM(kerns, 0) / params->num_groups;
size_t output_offset = PyGpuArray_STRIDE(*output, 0) / params->num_groups;
cudnnConvolutionFwdAlgo_t algo = params->conv_algo; cudnnConvolutionFwdAlgo_t algo = params->conv_algo;
...@@ -281,15 +284,17 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns, ...@@ -281,15 +284,17 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
cuda_wait(kerns->ga.data, GPUARRAY_CUDA_WAIT_READ); cuda_wait(kerns->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_wait((*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE); cuda_wait((*output)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
for ( int g = 0; g < params->num_groups; g++) {
err = cudnnConvolutionForward( err = cudnnConvolutionForward(
params->handle, params->handle,
alpha_p, alpha_p,
APPLY_SPECIFIC(input), PyGpuArray_DEV_DATA(input), APPLY_SPECIFIC(input), PyGpuArray_DEV_DATA(input) + input_offset * g,
APPLY_SPECIFIC(kerns), PyGpuArray_DEV_DATA(kerns), APPLY_SPECIFIC(kerns), PyGpuArray_DEV_DATA(kerns) + kern_offset * g,
desc, algo, desc, algo,
worksize == 0 ? NULL : *(void **)workspace, worksize, worksize == 0 ? NULL : *(void **)workspace, worksize,
beta_p, beta_p,
APPLY_SPECIFIC(output), PyGpuArray_DEV_DATA(*output)); APPLY_SPECIFIC(output), PyGpuArray_DEV_DATA(*output) + output_offset * g);
}
if (worksize != 0) if (worksize != 0)
gpudata_release(workspace); gpudata_release(workspace);
......
...@@ -28,7 +28,7 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output, ...@@ -28,7 +28,7 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
float af = alpha, bf = beta; float af = alpha, bf = beta;
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
if (PyGpuArray_DIMS(im)[1] != PyGpuArray_DIMS(kerns)[1]) { if (PyGpuArray_DIMS(im)[1] != PyGpuArray_DIMS(kerns)[1] * params->num_groups) {
PyErr_SetString(PyExc_ValueError, "images and kernel must have the same " PyErr_SetString(PyExc_ValueError, "images and kernel must have the same "
"stack size"); "stack size");
return 1; return 1;
...@@ -71,12 +71,15 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output, ...@@ -71,12 +71,15 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
return 0; return 0;
} }
if (c_set_tensorNd(output, APPLY_SPECIFIC(output)) == -1) if (c_set_tensor_for_conv(output, APPLY_SPECIFIC(output), params->num_groups) == -1)
return 1; return 1;
if (c_set_filter(kerns, APPLY_SPECIFIC(kerns)) == -1) if (c_set_filter(kerns, APPLY_SPECIFIC(kerns), params->num_groups) == -1)
return 1; return 1;
if (c_set_tensorNd(*input, APPLY_SPECIFIC(input)) == -1) if (c_set_tensor_for_conv(*input, APPLY_SPECIFIC(input), params->num_groups) == -1)
return 1; return 1;
size_t input_offset = PyGpuArray_STRIDE(*input, 0) / params->num_groups;
size_t kern_offset = PyGpuArray_STRIDE(kerns, 0) * PyGpuArray_DIM(kerns, 0) / params->num_groups;
size_t output_offset = PyGpuArray_STRIDE(output, 0) / params->num_groups;
cudnnConvolutionBwdDataAlgo_t algo = params->conv_algo; cudnnConvolutionBwdDataAlgo_t algo = params->conv_algo;
...@@ -93,7 +96,7 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output, ...@@ -93,7 +96,7 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
} }
if (PyGpuArray_NDIM(im) == 4) { if (PyGpuArray_NDIM(im) == 4) {
if ((PyGpuArray_DIMS(output)[0] != expected_output_dims[0]) || if ((PyGpuArray_DIMS(output)[0] != expected_output_dims[0]) ||
(PyGpuArray_DIMS(output)[1] != expected_output_dims[1]) || (PyGpuArray_DIMS(output)[1] / params->num_groups != expected_output_dims[1]) ||
(PyGpuArray_DIMS(output)[2] != expected_output_dims[2]) || (PyGpuArray_DIMS(output)[2] != expected_output_dims[2]) ||
(PyGpuArray_DIMS(output)[3] != expected_output_dims[3])) { (PyGpuArray_DIMS(output)[3] != expected_output_dims[3])) {
PyErr_Format(PyExc_ValueError, "impossible convolution output dim: expected %ldx%ldx%ldx%ld" PyErr_Format(PyExc_ValueError, "impossible convolution output dim: expected %ldx%ldx%ldx%ld"
...@@ -286,14 +289,17 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output, ...@@ -286,14 +289,17 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
cuda_wait(output->ga.data, GPUARRAY_CUDA_WAIT_READ); cuda_wait(output->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_wait((*input)->ga.data, GPUARRAY_CUDA_WAIT_WRITE); cuda_wait((*input)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
for ( int g = 0; g < params->num_groups; g++)
{
err = cudnnConvolutionBackwardData( err = cudnnConvolutionBackwardData(
params->handle, params->handle,
alpha_p, alpha_p,
APPLY_SPECIFIC(kerns), PyGpuArray_DEV_DATA(kerns), APPLY_SPECIFIC(kerns), PyGpuArray_DEV_DATA(kerns) + kern_offset * g,
APPLY_SPECIFIC(output), PyGpuArray_DEV_DATA(output), APPLY_SPECIFIC(output), PyGpuArray_DEV_DATA(output) + output_offset * g,
desc, algo, worksize == 0 ? NULL : *(void **)workspace, worksize, desc, algo, worksize == 0 ? NULL : *(void **)workspace, worksize,
beta_p, beta_p,
APPLY_SPECIFIC(input), PyGpuArray_DEV_DATA(*input)); APPLY_SPECIFIC(input), PyGpuArray_DEV_DATA(*input) + input_offset * g);
}
if (worksize != 0) if (worksize != 0)
gpudata_release(workspace); gpudata_release(workspace);
......
...@@ -28,7 +28,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output, ...@@ -28,7 +28,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
float af = alpha, bf = beta; float af = alpha, bf = beta;
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
if (PyGpuArray_DIMS(input)[1] != PyGpuArray_DIMS(km)[1]) { if (PyGpuArray_DIMS(input)[1] != PyGpuArray_DIMS(km)[1] * params->num_groups) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"GpuDnnConv images and kernel must have the same stack size"); "GpuDnnConv images and kernel must have the same stack size");
return 1; return 1;
...@@ -71,13 +71,17 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output, ...@@ -71,13 +71,17 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
return 0; return 0;
} }
if (c_set_tensorNd(input, APPLY_SPECIFIC(input)) == -1) if (c_set_tensor_for_conv(input, APPLY_SPECIFIC(input), params->num_groups) == -1)
return 1; return 1;
if (c_set_tensorNd(output, APPLY_SPECIFIC(output)) == -1) if (c_set_tensor_for_conv(output, APPLY_SPECIFIC(output), params->num_groups) == -1)
return 1; return 1;
if (c_set_filter(*kerns, APPLY_SPECIFIC(kerns)) == -1) if (c_set_filter(*kerns, APPLY_SPECIFIC(kerns), params->num_groups) == -1)
return 1; return 1;
size_t input_offset = PyGpuArray_STRIDE(input, 0) / params->num_groups;
size_t kern_offset = PyGpuArray_STRIDE(*kerns, 0) * PyGpuArray_DIM(*kerns, 0) / params->num_groups;
size_t output_offset = PyGpuArray_STRIDE(output, 0) / params->num_groups;
cudnnConvolutionBwdFilterAlgo_t algo = params->conv_algo; cudnnConvolutionBwdFilterAlgo_t algo = params->conv_algo;
cuda_enter(c->ctx); cuda_enter(c->ctx);
...@@ -93,7 +97,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output, ...@@ -93,7 +97,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
} }
if (PyGpuArray_NDIM(input) == 4) { if (PyGpuArray_NDIM(input) == 4) {
if ((PyGpuArray_DIMS(output)[0] != expected_output_dims[0]) || if ((PyGpuArray_DIMS(output)[0] != expected_output_dims[0]) ||
(PyGpuArray_DIMS(output)[1] != expected_output_dims[1]) || (PyGpuArray_DIMS(output)[1] / params->num_groups != expected_output_dims[1]) ||
(PyGpuArray_DIMS(output)[2] != expected_output_dims[2]) || (PyGpuArray_DIMS(output)[2] != expected_output_dims[2]) ||
(PyGpuArray_DIMS(output)[3] != expected_output_dims[3])) { (PyGpuArray_DIMS(output)[3] != expected_output_dims[3])) {
PyErr_Format(PyExc_ValueError, "impossible convolution output dim: expected %ldx%ldx%dx%ld" PyErr_Format(PyExc_ValueError, "impossible convolution output dim: expected %ldx%ldx%dx%ld"
...@@ -273,14 +277,18 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output, ...@@ -273,14 +277,18 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
cuda_wait(output->ga.data, GPUARRAY_CUDA_WAIT_READ); cuda_wait(output->ga.data, GPUARRAY_CUDA_WAIT_READ);
cuda_wait((*kerns)->ga.data, GPUARRAY_CUDA_WAIT_WRITE); cuda_wait((*kerns)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
for ( int g = 0; g < params->num_groups; g++)
{
err = cudnnConvolutionBackwardFilter( err = cudnnConvolutionBackwardFilter(
params->handle, params->handle,
alpha_p, alpha_p,
APPLY_SPECIFIC(input), PyGpuArray_DEV_DATA(input), APPLY_SPECIFIC(input), PyGpuArray_DEV_DATA(input) + input_offset * g ,
APPLY_SPECIFIC(output), PyGpuArray_DEV_DATA(output), APPLY_SPECIFIC(output), PyGpuArray_DEV_DATA(output) + output_offset * g,
desc, algo, worksize == 0 ? NULL : *(void **)workspace, worksize, desc, algo, worksize == 0 ? NULL : *(void **)workspace, worksize,
beta_p, beta_p,
APPLY_SPECIFIC(kerns), PyGpuArray_DEV_DATA(*kerns)); APPLY_SPECIFIC(kerns), PyGpuArray_DEV_DATA(*kerns) + kern_offset * g);
}
if (worksize != 0) if (worksize != 0)
gpudata_release(workspace); gpudata_release(workspace);
......
...@@ -1533,7 +1533,8 @@ def local_abstractconv_gemm(node): ...@@ -1533,7 +1533,8 @@ def local_abstractconv_gemm(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
if ((border_mode == 'full') and (subsample == (1, 1))):
if ((border_mode == 'full') and (subsample == (1, 1)) and node.op.num_groups == 1):
if not node.op.filter_flip: if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
# need to dimshuffle the kernel for full convolution # need to dimshuffle the kernel for full convolution
...@@ -1550,7 +1551,8 @@ def local_abstractconv_gemm(node): ...@@ -1550,7 +1551,8 @@ def local_abstractconv_gemm(node):
# By default use GpuCorrMM # By default use GpuCorrMM
rval = GpuCorrMM(border_mode, rval = GpuCorrMM(border_mode,
subsample, subsample,
filter_dilation)(gpu_contiguous(img), filter_dilation,
node.op.num_groups)(gpu_contiguous(img),
gpu_contiguous(kern)) gpu_contiguous(kern))
# call GpuCorrMM_gradWeights if good # call GpuCorrMM_gradWeights if good
...@@ -1669,7 +1671,8 @@ def local_abstractconv_gradweights_gemm(node): ...@@ -1669,7 +1671,8 @@ def local_abstractconv_gradweights_gemm(node):
rval = GpuCorrMM_gradWeights(border_mode=node.op.border_mode, rval = GpuCorrMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)( filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(
gpu_contiguous(img), gpu_contiguous(topgrad), shape) gpu_contiguous(img), gpu_contiguous(topgrad), shape)
if node.op.filter_flip: if node.op.filter_flip:
rval = rval[:, :, ::-1, ::-1] rval = rval[:, :, ::-1, ::-1]
...@@ -1713,7 +1716,8 @@ def local_abstractconv_gradinputs_gemm(node): ...@@ -1713,7 +1716,8 @@ def local_abstractconv_gradinputs_gemm(node):
rval = GpuCorrMM_gradInputs(border_mode=node.op.border_mode, rval = GpuCorrMM_gradInputs(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)( filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(
gpu_contiguous(kern), gpu_contiguous(topgrad), shape) gpu_contiguous(kern), gpu_contiguous(topgrad), shape)
return [rval] return [rval]
......
...@@ -25,6 +25,7 @@ from . import test_nnet ...@@ -25,6 +25,7 @@ from . import test_nnet
from .rnn_support import Model, GRU, LSTM, WrapperLayer from .rnn_support import Model, GRU, LSTM, WrapperLayer
from theano.configdefaults import SUPPORTED_DNN_CONV_ALGO_FWD from theano.configdefaults import SUPPORTED_DNN_CONV_ALGO_FWD
from theano.tensor.nnet.tests.test_abstract_conv import Grouped_conv_noOptim
try: try:
import pygpu import pygpu
...@@ -2263,3 +2264,37 @@ def test_dnn_rnn_lstm_grad_c(): ...@@ -2263,3 +2264,37 @@ def test_dnn_rnn_lstm_grad_c():
(i + 1) * len(cudnn_grads_layer)] (i + 1) * len(cudnn_grads_layer)]
for j, g in enumerate(cudnn_grads_layer): for j, g in enumerate(cudnn_grads_layer):
utt.assert_allclose(ref_grads_layer[j], g) utt.assert_allclose(ref_grads_layer[j], g)
def dconv2d(border_mode, subsample, filter_dilation, num_groups):
    """Return a forward-convolution callable bound to the given settings.

    The returned function takes ``(img, kern)`` and runs the cuDNN forward
    convolution with the captured border mode, stride, dilation and group
    count.
    """
    def _forward(img, kern):
        # Delegate to the cuDNN forward convolution with the bound settings.
        return dnn.dnn_conv(img, kern,
                            border_mode=border_mode,
                            subsample=subsample,
                            dilation=filter_dilation,
                            conv_mode='conv',
                            direction_hint='forward',
                            workmem=None,
                            algo=None,
                            precision=None,
                            num_groups=num_groups)
    return _forward
def dconv2dw(border_mode, subsample, filter_dilation, num_groups):
    """Return a weight-gradient callable bound to the given settings.

    The returned function takes ``(img, topgrad, kshp)`` and runs the cuDNN
    gradient-with-respect-to-weights convolution with the captured border
    mode, stride, dilation and group count.
    """
    def _grad_weights(img, topgrad, kshp):
        # Delegate to the cuDNN weight-gradient convolution.
        return dnn.dnn_gradweight(img, topgrad, kshp,
                                  border_mode=border_mode,
                                  subsample=subsample,
                                  dilation=filter_dilation,
                                  conv_mode='conv',
                                  precision=None,
                                  algo=None,
                                  num_groups=num_groups)
    return _grad_weights
def dconv2di(border_mode, subsample, filter_dilation, num_groups):
    """Return an input-gradient callable bound to the given settings.

    The returned function takes ``(kern, topgrad, imshp)`` and runs the cuDNN
    gradient-with-respect-to-inputs convolution with the captured border
    mode, stride, dilation and group count.
    """
    def _grad_inputs(kern, topgrad, imshp):
        # Delegate to the cuDNN input-gradient convolution.
        return dnn.dnn_gradinput(kern, topgrad, imshp,
                                 border_mode=border_mode,
                                 subsample=subsample,
                                 dilation=filter_dilation,
                                 conv_mode='conv',
                                 precision=None,
                                 algo=None,
                                 num_groups=num_groups)
    return _grad_inputs
class Cudnn_grouped_conv(Grouped_conv_noOptim):
    # Runs the shared grouped-convolution test suite (defined in
    # Grouped_conv_noOptim, from theano.tensor.nnet.tests.test_abstract_conv)
    # against the cuDNN implementation by overriding its configuration
    # attributes.
    # Compile on the GPU so the cuDNN ops are actually exercised.
    mode = mode_with_gpu
    # Factories producing the forward / grad-weights / grad-inputs callables
    # (wrapped in staticmethod so the base class can call them unbound).
    conv2d = staticmethod(dconv2d)
    conv2d_gradw = staticmethod(dconv2dw)
    conv2d_gradi = staticmethod(dconv2di)
    # Op classes the base suite checks for in the compiled graph.
    conv2d_op = dnn.GpuDnnConv
    conv2d_gradw_op = dnn.GpuDnnConvGradW
    conv2d_gradi_op = dnn.GpuDnnConvGradI
    # cuDNN performs correlation; no filter flip is applied by these wrappers.
    flip_filter = False
    is_dnn = True
...@@ -11,6 +11,7 @@ from theano.tensor.nnet.corr import CorrMM, CorrMM_gradWeights, CorrMM_gradInput ...@@ -11,6 +11,7 @@ from theano.tensor.nnet.corr import CorrMM, CorrMM_gradWeights, CorrMM_gradInput
from ..type import gpuarray_shared_constructor from ..type import gpuarray_shared_constructor
from ..blas import GpuCorrMM, GpuCorrMM_gradWeights, GpuCorrMM_gradInputs from ..blas import GpuCorrMM, GpuCorrMM_gradWeights, GpuCorrMM_gradInputs
from .config import mode_with_gpu, mode_without_gpu, ref_cast from .config import mode_with_gpu, mode_without_gpu, ref_cast
from theano.tensor.nnet.tests.test_abstract_conv import Grouped_conv_noOptim
class TestCorrMM(unittest.TestCase): class TestCorrMM(unittest.TestCase):
...@@ -219,3 +220,15 @@ class TestCorrMM(unittest.TestCase): ...@@ -219,3 +220,15 @@ class TestCorrMM(unittest.TestCase):
verify_grad=False) verify_grad=False)
self.run_gradinput(inputs_shape=(1, 1024, 3, 1), self.run_gradinput(inputs_shape=(1, 1024, 3, 1),
filters_shape=(1, 1, 1, 1024)) filters_shape=(1, 1, 1, 1024))
class TestGroupGpuCorr2d(Grouped_conv_noOptim):
    # Runs the shared grouped-convolution test suite (defined in
    # Grouped_conv_noOptim, from theano.tensor.nnet.tests.test_abstract_conv)
    # against the GpuCorrMM gemm-based implementation.
    mode = theano.compile.get_mode("FAST_RUN")
    # Unlike the cuDNN variant, the ops themselves serve directly as the
    # convolution callables here — GpuCorrMM instances are callable on
    # (img, kern)-style inputs.
    conv2d = GpuCorrMM
    conv2d_gradw = GpuCorrMM_gradWeights
    conv2d_gradi = GpuCorrMM_gradInputs
    # Op classes the base suite checks for in the compiled graph.
    conv2d_op = GpuCorrMM
    conv2d_gradw_op = GpuCorrMM_gradWeights
    conv2d_gradi_op = GpuCorrMM_gradInputs
    # The corrMM path flips filters to implement true convolution.
    flip_filter = True
    is_dnn = False
...@@ -39,7 +39,7 @@ from .abstract_conv import conv3d ...@@ -39,7 +39,7 @@ from .abstract_conv import conv3d
def conv2d(input, filters, input_shape=None, filter_shape=None, def conv2d(input, filters, input_shape=None, filter_shape=None,
border_mode='valid', subsample=(1, 1), filter_flip=True, border_mode='valid', subsample=(1, 1), filter_flip=True,
image_shape=None, filter_dilation=(1, 1), **kwargs): image_shape=None, filter_dilation=(1, 1), num_groups=1, **kwargs):
""" """
This function will build the symbolic graph for convolving a mini-batch of a This function will build the symbolic graph for convolving a mini-batch of a
stack of 2D inputs with a set of 2D filters. The implementation is modelled stack of 2D inputs with a set of 2D filters. The implementation is modelled
...@@ -103,6 +103,10 @@ def conv2d(input, filters, input_shape=None, filter_shape=None, ...@@ -103,6 +103,10 @@ def conv2d(input, filters, input_shape=None, filter_shape=None,
Factor by which to subsample (stride) the input. Factor by which to subsample (stride) the input.
Also called dilation elsewhere. Also called dilation elsewhere.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out convolutions separately
kwargs: Any other keyword arguments are accepted for backwards kwargs: Any other keyword arguments are accepted for backwards
compatibility, but will be ignored. compatibility, but will be ignored.
...@@ -152,12 +156,12 @@ def conv2d(input, filters, input_shape=None, filter_shape=None, ...@@ -152,12 +156,12 @@ def conv2d(input, filters, input_shape=None, filter_shape=None,
return abstract_conv2d(input, filters, input_shape, filter_shape, return abstract_conv2d(input, filters, input_shape, filter_shape,
border_mode, subsample, filter_flip, border_mode, subsample, filter_flip,
filter_dilation) filter_dilation, num_groups)
def conv2d_transpose(input, filters, output_shape, filter_shape=None, def conv2d_transpose(input, filters, output_shape, filter_shape=None,
border_mode='valid', input_dilation=(1, 1), border_mode='valid', input_dilation=(1, 1),
filter_flip=True, filter_dilation=(1, 1)): filter_flip=True, filter_dilation=(1, 1), num_groups=1):
""" """
This function will build the symbolic graph for applying a transposed This function will build the symbolic graph for applying a transposed
convolution over a mini-batch of a stack of 2D inputs with a set of 2D convolution over a mini-batch of a stack of 2D inputs with a set of 2D
...@@ -209,6 +213,10 @@ def conv2d_transpose(input, filters, output_shape, filter_shape=None, ...@@ -209,6 +213,10 @@ def conv2d_transpose(input, filters, output_shape, filter_shape=None,
Factor by which to subsample (stride) the input. Factor by which to subsample (stride) the input.
Also called dilation elsewhere. Also called dilation elsewhere.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out convolutions separately
Returns Returns
------- -------
Symbolic 4D tensor Symbolic 4D tensor
...@@ -235,4 +243,5 @@ def conv2d_transpose(input, filters, output_shape, filter_shape=None, ...@@ -235,4 +243,5 @@ def conv2d_transpose(input, filters, output_shape, filter_shape=None,
border_mode=border_mode, border_mode=border_mode,
subsample=input_dilation, subsample=input_dilation,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
...@@ -66,7 +66,6 @@ def get_conv_output_shape(image_shape, kernel_shape, ...@@ -66,7 +66,6 @@ def get_conv_output_shape(image_shape, kernel_shape,
""" """
bsize, imshp = image_shape[0], image_shape[2:] bsize, imshp = image_shape[0], image_shape[2:]
nkern, kshp = kernel_shape[0], kernel_shape[2:] nkern, kshp = kernel_shape[0], kernel_shape[2:]
if filter_dilation is None: if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int') filter_dilation = np.ones(len(subsample), dtype='int')
...@@ -139,7 +138,8 @@ def get_conv_shape_1axis(image_shape, kernel_shape, border_mode, ...@@ -139,7 +138,8 @@ def get_conv_shape_1axis(image_shape, kernel_shape, border_mode,
def get_conv_gradweights_shape(image_shape, top_shape, def get_conv_gradweights_shape(image_shape, top_shape,
border_mode, subsample, border_mode, subsample,
filter_dilation=None): filter_dilation=None,
num_groups=1):
""" """
This function tries to compute the kernel shape of convolution gradWeights. This function tries to compute the kernel shape of convolution gradWeights.
...@@ -167,6 +167,8 @@ def get_conv_gradweights_shape(image_shape, top_shape, ...@@ -167,6 +167,8 @@ def get_conv_gradweights_shape(image_shape, top_shape,
filter_dilation: tuple of int (symbolic or numeric). Its two or three filter_dilation: tuple of int (symbolic or numeric). Its two or three
elements correspond respectively to the dilation on height and elements correspond respectively to the dilation on height and
width axis. width axis.
num_groups: An int which specifies the number of separate groups to
be divided into.
Returns Returns
------- -------
...@@ -181,6 +183,9 @@ def get_conv_gradweights_shape(image_shape, top_shape, ...@@ -181,6 +183,9 @@ def get_conv_gradweights_shape(image_shape, top_shape,
if filter_dilation is None: if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int') filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1:
assert len(subsample) == 2
nchan = nchan // num_groups
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
out_shp = tuple(get_conv_gradweights_shape_1axis( out_shp = tuple(get_conv_gradweights_shape_1axis(
...@@ -245,7 +250,8 @@ def get_conv_gradweights_shape_1axis(image_shape, top_shape, border_mode, ...@@ -245,7 +250,8 @@ def get_conv_gradweights_shape_1axis(image_shape, top_shape, border_mode,
def get_conv_gradinputs_shape(kernel_shape, top_shape, def get_conv_gradinputs_shape(kernel_shape, top_shape,
border_mode, subsample, border_mode, subsample,
filter_dilation=None): filter_dilation=None,
num_groups=1):
""" """
This function tries to compute the image shape of convolution gradInputs. This function tries to compute the image shape of convolution gradInputs.
...@@ -273,6 +279,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape, ...@@ -273,6 +279,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape,
filter_dilation: tuple of int (symbolic or numeric). Its two or three filter_dilation: tuple of int (symbolic or numeric). Its two or three
elements correspond respectively to the dilation on height and elements correspond respectively to the dilation on height and
width axis. width axis.
num_groups: An int which specifies the number of separate groups to
be divided into.
Returns Returns
------- -------
...@@ -286,6 +294,9 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape, ...@@ -286,6 +294,9 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape,
if filter_dilation is None: if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int') filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1:
assert len(subsample) == 2
nkern = nkern * num_groups
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
out_shp = tuple(get_conv_gradinputs_shape_1axis( out_shp = tuple(get_conv_gradinputs_shape_1axis(
...@@ -512,7 +523,8 @@ def conv2d(input, ...@@ -512,7 +523,8 @@ def conv2d(input,
border_mode='valid', border_mode='valid',
subsample=(1, 1), subsample=(1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
"""This function will build the symbolic graph for convolving a mini-batch of a """This function will build the symbolic graph for convolving a mini-batch of a
stack of 2D inputs with a set of 2D filters. The implementation is modelled stack of 2D inputs with a set of 2D filters. The implementation is modelled
after Convolutional Neural Networks (CNN). after Convolutional Neural Networks (CNN).
...@@ -527,7 +539,8 @@ def conv2d(input, ...@@ -527,7 +539,8 @@ def conv2d(input,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return conv_op(input, filters) return conv_op(input, filters)
...@@ -637,7 +650,8 @@ def conv2d_grad_wrt_inputs(output_grad, ...@@ -637,7 +650,8 @@ def conv2d_grad_wrt_inputs(output_grad,
border_mode='valid', border_mode='valid',
subsample=(1, 1), subsample=(1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its inputs """Compute conv output gradient w.r.t its inputs
This function builds the symbolic graph for getting the This function builds the symbolic graph for getting the
...@@ -710,6 +724,9 @@ def conv2d_grad_wrt_inputs(output_grad, ...@@ -710,6 +724,9 @@ def conv2d_grad_wrt_inputs(output_grad,
filter_dilation : tuple of len 2 filter_dilation : tuple of len 2
The filter dilation used in the forward pass. The filter dilation used in the forward pass.
Also known as input striding. Also known as input striding.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups. Each which carry out convolutions separately
Returns Returns
------- -------
...@@ -760,7 +777,8 @@ def conv2d_grad_wrt_inputs(output_grad, ...@@ -760,7 +777,8 @@ def conv2d_grad_wrt_inputs(output_grad,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return grad_input_op(filters, output_grad, input_shape[-2:]) return grad_input_op(filters, output_grad, input_shape[-2:])
...@@ -907,7 +925,8 @@ def conv2d_grad_wrt_weights(input, ...@@ -907,7 +925,8 @@ def conv2d_grad_wrt_weights(input,
border_mode='valid', border_mode='valid',
subsample=(1, 1), subsample=(1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its weights """Compute conv output gradient w.r.t its weights
This function will build the symbolic graph for getting the This function will build the symbolic graph for getting the
...@@ -972,6 +991,9 @@ def conv2d_grad_wrt_weights(input, ...@@ -972,6 +991,9 @@ def conv2d_grad_wrt_weights(input,
filter_dilation : tuple of len 2 filter_dilation : tuple of len 2
The filter dilation used in the forward pass. The filter dilation used in the forward pass.
Also known as input striding. Also known as input striding.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups. Each which carry out convolutions separately
Returns Returns
------- -------
...@@ -1022,7 +1044,8 @@ def conv2d_grad_wrt_weights(input, ...@@ -1022,7 +1044,8 @@ def conv2d_grad_wrt_weights(input,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return gradWeight_op(input, output_grad, filter_shape[-2:]) return gradWeight_op(input, output_grad, filter_shape[-2:])
...@@ -1392,11 +1415,11 @@ class BaseAbstractConv(Op): ...@@ -1392,11 +1415,11 @@ class BaseAbstractConv(Op):
""" """
check_broadcast = False check_broadcast = False
__props__ = ('convdim', 'border_mode', 'subsample', 'filter_flip', __props__ = ('convdim', 'border_mode', 'subsample', 'filter_flip',
'imshp', 'kshp', 'filter_dilation') 'imshp', 'kshp', 'filter_dilation', 'num_groups')
def __init__(self, convdim, def __init__(self, convdim,
imshp=None, kshp=None, border_mode="valid", imshp=None, kshp=None, border_mode="valid",
subsample=None, filter_flip=True, filter_dilation=None): subsample=None, filter_flip=True, filter_dilation=None, num_groups=1):
self.convdim = convdim self.convdim = convdim
if convdim not in (2, 3): if convdim not in (2, 3):
...@@ -1458,6 +1481,11 @@ class BaseAbstractConv(Op): ...@@ -1458,6 +1481,11 @@ class BaseAbstractConv(Op):
if len(filter_dilation) != convdim: if len(filter_dilation) != convdim:
raise ValueError("filter_dilation must have {} elements".format(convdim)) raise ValueError("filter_dilation must have {} elements".format(convdim))
self.filter_dilation = tuple(filter_dilation) self.filter_dilation = tuple(filter_dilation)
if num_groups < 1:
raise ValueError("num_groups must have value greater than zero")
elif num_groups > 1 and convdim == 3:
raise ValueError("grouped convolution not supported for 3D convolutions")
self.num_groups = num_groups
def do_constant_folding(self, node): def do_constant_folding(self, node):
# Disable constant folding since there is no implementation. # Disable constant folding since there is no implementation.
...@@ -1471,20 +1499,20 @@ class BaseAbstractConv(Op): ...@@ -1471,20 +1499,20 @@ class BaseAbstractConv(Op):
# flops for any direction, sampling, padding, and border mode # flops for any direction, sampling, padding, and border mode
inputs, filters = inp inputs, filters = inp
outputs, = outp outputs, = outp
assert inputs[1] == filters[1] assert inputs[1] == (filters[1] * self.num_groups)
# nb mul and add by output pixel # nb mul and add by output pixel
flops = filters[2] * filters[3] * 2 flops = filters[2] * filters[3] * 2
# nb flops by output image # nb flops by output image
flops *= outputs[2] * outputs[3] flops *= outputs[2] * outputs[3]
# nb patch multiplied # nb patch multiplied
flops *= inputs[1] * filters[0] * inputs[0] flops *= inputs[1] * filters[0] * inputs[0] / self.num_groups
return flops return flops
else: else:
# TODO implement for convdim == 3 # TODO implement for convdim == 3
raise NotImplementedError( raise NotImplementedError(
'flops not implemented for convdim={}', self.convdim) 'flops not implemented for convdim={}', self.convdim)
def conv(self, img, kern, mode="valid", dilation=1): def conv(self, img, kern, mode="valid", dilation=1, num_groups=1):
""" """
Basic slow Python 2D or 3D convolution for DebugMode Basic slow Python 2D or 3D convolution for DebugMode
""" """
...@@ -1517,18 +1545,31 @@ class BaseAbstractConv(Op): ...@@ -1517,18 +1545,31 @@ class BaseAbstractConv(Op):
] = kern ] = kern
if self.convdim == 2: if self.convdim == 2:
if img.shape[1] % self.num_groups != 0:
raise ValueError(
'number of input channels must be divible by num_groups')
if kern.shape[0] % self.num_groups != 0:
raise ValueError(
'number of filters must be divisible by num_groups')
if img.shape[1] // num_groups != kern.shape[1]:
raise ValueError(
'the number of input channels in the kernel should '
'specify the number of channels of 1 group')
val = _valfrommode(mode) val = _valfrommode(mode)
bval = _bvalfromboundary('fill') bval = _bvalfromboundary('fill')
input_channel_offset = img.shape[1] // self.num_groups
output_channel_offset = kern.shape[0] // self.num_groups
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter('ignore', np.ComplexWarning) warnings.simplefilter('ignore', np.ComplexWarning)
for b in xrange(img.shape[0]): for b in xrange(img.shape[0]):
for n in xrange(kern.shape[0]): for g in xrange(self.num_groups):
for im0 in xrange(img.shape[1]): for n in xrange(output_channel_offset):
for im0 in xrange(input_channel_offset):
# some cast generates a warning here # some cast generates a warning here
out[b, n, ...] += _convolve2d(img[b, im0, ...], out[b, g * output_channel_offset + n, ...] += _convolve2d(img[b, g * input_channel_offset + im0, ...],
dilated_kern[n, im0, ...], dilated_kern[g * output_channel_offset + n,
1, val, bval, 0) im0, ...], 1, val, bval, 0)
elif self.convdim == 3: elif self.convdim == 3:
for b in xrange(img.shape[0]): for b in xrange(img.shape[0]):
for n in xrange(kern.shape[0]): for n in xrange(kern.shape[0]):
...@@ -1554,13 +1595,15 @@ class AbstractConv(BaseAbstractConv): ...@@ -1554,13 +1595,15 @@ class AbstractConv(BaseAbstractConv):
border_mode="valid", border_mode="valid",
subsample=None, subsample=None,
filter_flip=True, filter_flip=True,
filter_dilation=None): filter_dilation=None,
num_groups=1):
super(AbstractConv, self).__init__(convdim=convdim, super(AbstractConv, self).__init__(convdim=convdim,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
def make_node(self, img, kern): def make_node(self, img, kern):
# Make sure both inputs are Variables with the same Type # Make sure both inputs are Variables with the same Type
...@@ -1622,7 +1665,7 @@ class AbstractConv(BaseAbstractConv): ...@@ -1622,7 +1665,7 @@ class AbstractConv(BaseAbstractConv):
img = new_img img = new_img
if not self.filter_flip: if not self.filter_flip:
kern = kern[(slice(None), slice(None)) + (slice(None, None, -1),) * self.convdim] kern = kern[(slice(None), slice(None)) + (slice(None, None, -1),) * self.convdim]
conv_out = self.conv(img, kern, mode="valid", dilation=self.filter_dilation) conv_out = self.conv(img, kern, mode="valid", dilation=self.filter_dilation, num_groups=self.num_groups)
conv_out = conv_out[(slice(None), slice(None)) + conv_out = conv_out[(slice(None), slice(None)) +
tuple(slice(None, None, self.subsample[i]) tuple(slice(None, None, self.subsample[i])
for i in range(self.convdim))] for i in range(self.convdim))]
...@@ -1630,6 +1673,9 @@ class AbstractConv(BaseAbstractConv): ...@@ -1630,6 +1673,9 @@ class AbstractConv(BaseAbstractConv):
o[0] = node.outputs[0].type.filter(conv_out) o[0] = node.outputs[0].type.filter(conv_out)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if self.num_groups > 1:
raise NotImplementedError(
'Rop not implemented for grouped convolutions')
rval = None rval = None
if eval_points[0] is not None: if eval_points[0] is not None:
rval = self.make_node(eval_points[0], inputs[1]).outputs[0] rval = self.make_node(eval_points[0], inputs[1]).outputs[0]
...@@ -1668,13 +1714,15 @@ class AbstractConv2d(AbstractConv): ...@@ -1668,13 +1714,15 @@ class AbstractConv2d(AbstractConv):
border_mode="valid", border_mode="valid",
subsample=(1, 1), subsample=(1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
super(AbstractConv2d, self).__init__(convdim=2, super(AbstractConv2d, self).__init__(convdim=2,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads): def grad(self, inp, grads):
bottom, weights = inp bottom, weights = inp
...@@ -1684,13 +1732,15 @@ class AbstractConv2d(AbstractConv): ...@@ -1684,13 +1732,15 @@ class AbstractConv2d(AbstractConv):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)( self.filter_dilation,
num_groups=self.num_groups)(
weights, top, bottom.shape[-2:], add_assert_shape=False) weights, top, bottom.shape[-2:], add_assert_shape=False)
d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp, d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp,
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)( self.filter_dilation,
num_groups=self.num_groups)(
bottom, top, weights.shape[-2:], add_assert_shape=False) bottom, top, weights.shape[-2:], add_assert_shape=False)
...@@ -1772,13 +1822,15 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -1772,13 +1822,15 @@ class AbstractConv_gradWeights(BaseAbstractConv):
border_mode="valid", border_mode="valid",
subsample=None, subsample=None,
filter_flip=True, filter_flip=True,
filter_dilation=None): filter_dilation=None,
num_groups=1):
super(AbstractConv_gradWeights, self).__init__(convdim=convdim, super(AbstractConv_gradWeights, self).__init__(convdim=convdim,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
# Update shape/height_width # Update shape/height_width
def make_node(self, img, topgrad, shape, add_assert_shape=True): def make_node(self, img, topgrad, shape, add_assert_shape=True):
...@@ -1856,7 +1908,19 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -1856,7 +1908,19 @@ class AbstractConv_gradWeights(BaseAbstractConv):
(slice(None, None, -1),) * self.convdim) (slice(None, None, -1),) * self.convdim)
topgrad = topgrad.transpose(axes_order)[flip_filters] topgrad = topgrad.transpose(axes_order)[flip_filters]
img = img.transpose(axes_order) img = img.transpose(axes_order)
kern = self.conv(img, topgrad, mode="valid")
def correct_for_groups(mat):
mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
mat = mat.transpose((1, 0, 2, 3, 4))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-2:])
return mat
if self.num_groups > 1:
img = correct_for_groups(img)
kern = self.conv(img, topgrad, mode="valid", num_groups=self.num_groups)
if any(self.filter_dilation[i] > 1 for i in range(self.convdim)): if any(self.filter_dilation[i] > 1 for i in range(self.convdim)):
kern = kern[(slice(None), slice(None)) + kern = kern[(slice(None), slice(None)) +
tuple(slice(None, None, self.filter_dilation[i]) tuple(slice(None, None, self.filter_dilation[i])
...@@ -1878,6 +1942,10 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -1878,6 +1942,10 @@ class AbstractConv_gradWeights(BaseAbstractConv):
imshp = input_shapes[0] imshp = input_shapes[0]
topshp = input_shapes[1] topshp = input_shapes[1]
kshp = self.kshp[:] if self.kshp is not None else [None] * (2 + self.convdim) kshp = self.kshp[:] if self.kshp is not None else [None] * (2 + self.convdim)
if self.num_groups > 1:
fallback_kshp = ([topshp[1], imshp[1] // self.num_groups] +
[node.inputs[2][i] for i in range(self.convdim)])
else:
fallback_kshp = ([topshp[1], imshp[1]] + fallback_kshp = ([topshp[1], imshp[1]] +
[node.inputs[2][i] for i in range(self.convdim)]) [node.inputs[2][i] for i in range(self.convdim)])
kshp = [fallback_kshp[i] if kshp[i] is None else kshp[i] kshp = [fallback_kshp[i] if kshp[i] is None else kshp[i]
...@@ -1901,13 +1969,15 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights): ...@@ -1901,13 +1969,15 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
border_mode="valid", border_mode="valid",
subsample=(1, 1), subsample=(1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
super(AbstractConv2d_gradWeights, self).__init__(convdim=2, super(AbstractConv2d_gradWeights, self).__init__(convdim=2,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads): def grad(self, inp, grads):
bottom, top = inp[:2] bottom, top = inp[:2]
...@@ -1916,7 +1986,8 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights): ...@@ -1916,7 +1986,8 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(weights, self.filter_dilation,
self.num_groups)(weights,
top, top,
bottom.shape[-2:]) bottom.shape[-2:])
d_top = AbstractConv2d(self.imshp, d_top = AbstractConv2d(self.imshp,
...@@ -1924,7 +1995,8 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights): ...@@ -1924,7 +1995,8 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used # Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer # for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable. # that the dimensions are broadcastable.
...@@ -2011,13 +2083,15 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2011,13 +2083,15 @@ class AbstractConv_gradInputs(BaseAbstractConv):
border_mode="valid", border_mode="valid",
subsample=None, subsample=None,
filter_flip=True, filter_flip=True,
filter_dilation=None): filter_dilation=None,
num_groups=1):
super(AbstractConv_gradInputs, self).__init__(convdim=convdim, super(AbstractConv_gradInputs, self).__init__(convdim=convdim,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
# Update shape/height_width # Update shape/height_width
def make_node(self, kern, topgrad, shape, add_assert_shape=True): def make_node(self, kern, topgrad, shape, add_assert_shape=True):
...@@ -2041,6 +2115,10 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2041,6 +2115,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
'filters does not match given kshp.') 'filters does not match given kshp.')
shape = as_tensor_variable(shape) shape = as_tensor_variable(shape)
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0],
False] + ([False] * self.convdim)
else:
broadcastable = [topgrad.type.broadcastable[0], broadcastable = [topgrad.type.broadcastable[0],
kern.type.broadcastable[1]] + ([False] * self.convdim) kern.type.broadcastable[1]] + ([False] * self.convdim)
output = kern.type.clone(broadcastable=broadcastable)() output = kern.type.clone(broadcastable=broadcastable)()
...@@ -2097,10 +2175,20 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2097,10 +2175,20 @@ class AbstractConv_gradInputs(BaseAbstractConv):
axes_order = (1, 0) + tuple(range(2, self.convdim + 2)) axes_order = (1, 0) + tuple(range(2, self.convdim + 2))
flip_filters = ((slice(None), slice(None)) + flip_filters = ((slice(None), slice(None)) +
(slice(None, None, -1),) * self.convdim) (slice(None, None, -1),) * self.convdim)
def correct_for_groups(mat):
mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
mat = mat.transpose((1, 0, 2, 3, 4))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-2:])
return mat
kern = correct_for_groups(kern)
kern = kern.transpose(axes_order) kern = kern.transpose(axes_order)
if self.filter_flip: if self.filter_flip:
topgrad = topgrad[flip_filters] topgrad = topgrad[flip_filters]
img = self.conv(topgrad, kern, mode="full", dilation=self.filter_dilation) img = self.conv(topgrad, kern, mode="full", dilation=self.filter_dilation, num_groups=self.num_groups)
if self.filter_flip: if self.filter_flip:
img = img[flip_filters] img = img[flip_filters]
if any(p > 0 for p in pad): if any(p > 0 for p in pad):
...@@ -2120,6 +2208,10 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -2120,6 +2208,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
kshp = input_shapes[0] kshp = input_shapes[0]
topshp = input_shapes[1] topshp = input_shapes[1]
imshp = self.imshp[:] if self.imshp is not None else [None] * (2 + self.convdim) imshp = self.imshp[:] if self.imshp is not None else [None] * (2 + self.convdim)
if self.num_groups > 1:
fallback_imshp = ([topshp[0], kshp[1] * self.num_groups] +
[node.inputs[2][i] for i in range(self.convdim)])
else:
fallback_imshp = ([topshp[0], kshp[1]] + fallback_imshp = ([topshp[0], kshp[1]] +
[node.inputs[2][i] for i in range(self.convdim)]) [node.inputs[2][i] for i in range(self.convdim)])
imshp = [fallback_imshp[i] if imshp[i] is None else imshp[i] imshp = [fallback_imshp[i] if imshp[i] is None else imshp[i]
...@@ -2144,13 +2236,15 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs): ...@@ -2144,13 +2236,15 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
border_mode="valid", border_mode="valid",
subsample=(1, 1), subsample=(1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1)): filter_dilation=(1, 1),
num_groups=1):
super(AbstractConv2d_gradInputs, self).__init__(convdim=2, super(AbstractConv2d_gradInputs, self).__init__(convdim=2,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads): def grad(self, inp, grads):
weights, top = inp[:2] weights, top = inp[:2]
...@@ -2159,14 +2253,16 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs): ...@@ -2159,14 +2253,16 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)( self.filter_dilation,
self.num_groups)(
bottom, top, bottom, top,
weights.shape[-2:]) weights.shape[-2:])
d_top = AbstractConv2d(self.imshp, self.kshp, d_top = AbstractConv2d(self.imshp, self.kshp,
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used # Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer # for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable. # that the dimensions are broadcastable.
......
...@@ -40,9 +40,11 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -40,9 +40,11 @@ class BaseCorrMM(gof.OpenMPOp):
Perform subsampling of the output (default: (1, 1)). Perform subsampling of the output (default: (1, 1)).
filter_dilation filter_dilation
Perform dilated correlation (default: (1,1)) Perform dilated correlation (default: (1,1))
num_groups
Perform grouped convolutions (default: 1)
""" """
check_broadcast = False check_broadcast = False
__props__ = ('border_mode', 'subsample', 'filter_dilation') __props__ = ('border_mode', 'subsample', 'filter_dilation', 'num_groups')
_direction = None _direction = None
...@@ -51,10 +53,11 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -51,10 +53,11 @@ class BaseCorrMM(gof.OpenMPOp):
('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2 ('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2
dH=int64, dW=int64, dH=int64, dW=int64,
dilH=int64, dilW=int64, dilH=int64, dilW=int64,
padH=int64, padW=int64) padH=int64, padW=int64,
num_groups=int64)
def __init__(self, border_mode="valid", subsample=(1, 1), def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1), openmp=None): filter_dilation=(1, 1), num_groups=1, openmp=None):
super(BaseCorrMM, self).__init__(openmp=openmp) super(BaseCorrMM, self).__init__(openmp=openmp)
if isinstance(border_mode, integer_types): if isinstance(border_mode, integer_types):
if border_mode < 0: if border_mode < 0:
...@@ -97,6 +100,9 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -97,6 +100,9 @@ class BaseCorrMM(gof.OpenMPOp):
if self._direction not in ["forward", "backprop weights", "backprop inputs"]: if self._direction not in ["forward", "backprop weights", "backprop inputs"]:
raise ValueError("_direction must be one of 'forward', " raise ValueError("_direction must be one of 'forward', "
"'backprop weights', 'backprop inputs'") "'backprop weights', 'backprop inputs'")
if num_groups < 1:
raise ValueError("Number of groups should be greater than 0")
self.num_groups = num_groups
@property @property
def pad(self): def pad(self):
...@@ -124,11 +130,12 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -124,11 +130,12 @@ class BaseCorrMM(gof.OpenMPOp):
padW = property(lambda self: self.pad[1]) padW = property(lambda self: self.pad[1])
def __str__(self): def __str__(self):
return '%s{%s, %s, %s}' % ( return '%s{%s, %s, %s, %s}' % (
self.__class__.__name__, self.__class__.__name__,
self.border_mode, self.border_mode,
str(self.subsample), str(self.subsample),
str(self.filter_dilation)) str(self.filter_dilation),
str(self.num_groups))
@staticmethod @staticmethod
def as_common_dtype(in1, in2): def as_common_dtype(in1, in2):
...@@ -138,6 +145,11 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -138,6 +145,11 @@ class BaseCorrMM(gof.OpenMPOp):
dtype = theano.scalar.upcast(in1.dtype, in2.dtype) dtype = theano.scalar.upcast(in1.dtype, in2.dtype)
return in1.astype(dtype), in2.astype(dtype) return in1.astype(dtype), in2.astype(dtype)
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def c_support_code(self): def c_support_code(self):
ccodes = blas_headers.blas_header_text() ccodes = blas_headers.blas_header_text()
if self.blas_type == 'openblas': if self.blas_type == 'openblas':
...@@ -167,7 +179,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -167,7 +179,7 @@ class BaseCorrMM(gof.OpenMPOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying any of the support_code_files
return (6, self.openmp, blas_header_version()) return (7, self.openmp, blas_header_version())
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of # REMEMBER TO RAISE c_code_cache_version when changing any of
...@@ -274,6 +286,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -274,6 +286,7 @@ class BaseCorrMM(gof.OpenMPOp):
int dilW = %(params)s->dilW; int dilW = %(params)s->dilW;
int padH = %(params)s->padH; int padH = %(params)s->padH;
int padW = %(params)s->padW; int padW = %(params)s->padW;
int numgroups = %(params)s->num_groups;
PyArrayObject * bottom = %(bottom)s; PyArrayObject * bottom = %(bottom)s;
PyArrayObject * weights = %(weights)s; PyArrayObject * weights = %(weights)s;
...@@ -386,7 +399,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -386,7 +399,7 @@ class BaseCorrMM(gof.OpenMPOp):
// output is weights: (num_filters, num_channels, height, width) // output is weights: (num_filters, num_channels, height, width)
// height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1 // height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
out_dim[0] = (npy_intp)PyArray_DIMS(top)[1]; out_dim[0] = (npy_intp)PyArray_DIMS(top)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1]; out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1] / numgroups;
out_dim[2] = (npy_intp)kH; // already inferred further above out_dim[2] = (npy_intp)kH; // already inferred further above
out_dim[3] = (npy_intp)kW; // how convenient out_dim[3] = (npy_intp)kW; // how convenient
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0) if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
...@@ -409,7 +422,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -409,7 +422,7 @@ class BaseCorrMM(gof.OpenMPOp):
// output is bottom: (batchsize, num_channels, height, width) // output is bottom: (batchsize, num_channels, height, width)
// height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad // height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
out_dim[0] = (npy_intp)PyArray_DIMS(top)[0]; out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1]; out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1] * numgroups;
out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH); out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW); out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0) if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
...@@ -465,7 +478,7 @@ class BaseCorrMM(gof.OpenMPOp): ...@@ -465,7 +478,7 @@ class BaseCorrMM(gof.OpenMPOp):
} }
// Call corrMM code // Call corrMM code
out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW); out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW, numgroups );
if (out2==NULL){ if (out2==NULL){
%(fail)s %(fail)s
} }
...@@ -541,11 +554,13 @@ class CorrMM(BaseCorrMM): ...@@ -541,11 +554,13 @@ class CorrMM(BaseCorrMM):
top, = grads top, = grads
d_bottom = CorrMM_gradInputs(self.border_mode, d_bottom = CorrMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, top, self.filter_dilation,
self.num_groups)(weights, top,
bottom.shape[-2:]) bottom.shape[-2:])
d_weights = CorrMM_gradWeights(self.border_mode, d_weights = CorrMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, top, self.filter_dilation,
self.num_groups)(bottom, top,
weights.shape[-2:]) weights.shape[-2:])
return d_bottom, d_weights return d_bottom, d_weights
...@@ -600,6 +615,7 @@ class CorrMM_gradWeights(BaseCorrMM): ...@@ -600,6 +615,7 @@ class CorrMM_gradWeights(BaseCorrMM):
imshp = input_shape[0] imshp = input_shape[0]
topshp = input_shape[1] topshp = input_shape[1]
ssize, imshp = imshp[1], list(imshp[2:]) ssize, imshp = imshp[1], list(imshp[2:])
ssize = ssize // self.num_groups
nkern, topshp = topshp[1], list(topshp[2:]) nkern, topshp = topshp[1], list(topshp[2:])
height_width = node.inputs[-2:] height_width = node.inputs[-2:]
if ((dH != 1) or (padH == -1)): if ((dH != 1) or (padH == -1)):
...@@ -632,11 +648,13 @@ class CorrMM_gradWeights(BaseCorrMM): ...@@ -632,11 +648,13 @@ class CorrMM_gradWeights(BaseCorrMM):
weights, = grads weights, = grads
d_bottom = CorrMM_gradInputs(self.border_mode, d_bottom = CorrMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, top, self.filter_dilation,
self.num_groups)(weights, top,
bottom.shape[-2:]) bottom.shape[-2:])
d_top = CorrMM(self.border_mode, d_top = CorrMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
d_height_width = ((theano.gradient.DisconnectedType()(),) * 2 d_height_width = ((theano.gradient.DisconnectedType()(),) * 2
if len(inp) == 4 else ()) if len(inp) == 4 else ())
return (d_bottom, d_top) + d_height_width return (d_bottom, d_top) + d_height_width
...@@ -678,6 +696,10 @@ class CorrMM_gradInputs(BaseCorrMM): ...@@ -678,6 +696,10 @@ class CorrMM_gradInputs(BaseCorrMM):
height_width = [as_tensor_variable(shape[0]).astype('int64'), height_width = [as_tensor_variable(shape[0]).astype('int64'),
as_tensor_variable(shape[1]).astype('int64')] as_tensor_variable(shape[1]).astype('int64')]
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0], False,
False, False]
else:
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1], broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False] False, False]
dtype = kern.type.dtype dtype = kern.type.dtype
...@@ -698,6 +720,7 @@ class CorrMM_gradInputs(BaseCorrMM): ...@@ -698,6 +720,7 @@ class CorrMM_gradInputs(BaseCorrMM):
kshp = input_shape[0] kshp = input_shape[0]
topshp = input_shape[1] topshp = input_shape[1]
ssize, kshp = kshp[1], list(kshp[2:]) ssize, kshp = kshp[1], list(kshp[2:])
ssize = ssize * self.num_groups
bsize, topshp = topshp[0], list(topshp[2:]) bsize, topshp = topshp[0], list(topshp[2:])
height_width = node.inputs[-2:] height_width = node.inputs[-2:]
if padH == -1: if padH == -1:
...@@ -738,12 +761,14 @@ class CorrMM_gradInputs(BaseCorrMM): ...@@ -738,12 +761,14 @@ class CorrMM_gradInputs(BaseCorrMM):
bottom, = grads bottom, = grads
d_weights = CorrMM_gradWeights(self.border_mode, d_weights = CorrMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, self.filter_dilation,
self.num_groups)(bottom,
top, top,
weights.shape[-2:]) weights.shape[-2:])
d_top = CorrMM(self.border_mode, d_top = CorrMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
d_height_width = ((theano.gradient.DisconnectedType()(),) * d_height_width = ((theano.gradient.DisconnectedType()(),) *
2 if len(inp) == 4 else ()) 2 if len(inp) == 4 else ())
return (d_weights, d_top) + d_height_width return (d_weights, d_top) + d_height_width
......
...@@ -106,7 +106,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -106,7 +106,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const int dilH = 1, const int dilH = 1,
const int dilW = 1, const int dilW = 1,
const int padH = 0, const int padH = 0,
const int padW = 0) const int padW = 0,
const int numgroups = 1)
{ {
if (PyArray_NDIM(bottom) != 4) if (PyArray_NDIM(bottom) != 4)
{ {
...@@ -155,7 +156,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -155,7 +156,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const int nFilters = PyArray_DIMS(weight)[0]; const int nFilters = PyArray_DIMS(weight)[0];
const int kH = PyArray_DIMS(weight)[2]; const int kH = PyArray_DIMS(weight)[2];
const int kW = PyArray_DIMS(weight)[3]; const int kW = PyArray_DIMS(weight)[3];
if (nChannels != PyArray_DIMS(weight)[1]) { if (nChannels != (PyArray_DIMS(weight)[1] * numgroups)) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"CorrMM images and kernel must have the same stack size\n"); "CorrMM images and kernel must have the same stack size\n");
return NULL; return NULL;
...@@ -214,12 +215,16 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -214,12 +215,16 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
} }
// Define some useful variables // Define some useful variables
const int bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f; const int batch_bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f;
const int top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f; const int group_bottom_stride = (PyArray_STRIDES(bottom)[1] * nChannels / numgroups)/%(n_bytes)f;
const int K_ = col_dim[1]; const int batch_top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f;
const int group_top_stride = (PyArray_STRIDES(top)[1] * nFilters / numgroups)/%(n_bytes)f;
const int K_ = col_dim[1] / numgroups;
const int N_ = col_dim[2]; const int N_ = col_dim[2];
const int col_stride = (K_ * N_); const int col_stride = (K_ * N_ * numgroups);
const int M_ = nFilters; const int group_col_stride = (K_ * N_);
const int group_weight_stride = (PyArray_STRIDES(weight)[0] * nFilters / numgroups)/%(n_bytes)f;
const int M_ = nFilters / numgroups;
const %(c_float_type)s one = 1.0; const %(c_float_type)s one = 1.0;
const %(c_float_type)s zero = 0.0; const %(c_float_type)s zero = 0.0;
char NTrans = 'N'; char NTrans = 'N';
...@@ -253,17 +258,19 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -253,17 +258,19 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride, nChannels,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
for ( int g = 0; g < numgroups; ++g){
// Second, gemm // Second, gemm
%(gemm)s(&NTrans, &NTrans, %(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_, &N_, &M_, &K_,
&one, &one,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride, &N_, (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_, (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero, &zero,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_); (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_);
}
} }
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
...@@ -304,7 +311,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -304,7 +311,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
output = weight; output = weight;
npy_intp weight_dim[2]; npy_intp weight_dim[2];
weight_dim[0] = (npy_intp)max_threads; weight_dim[0] = (npy_intp)max_threads;
weight_dim[1] = (npy_intp)(M_ * K_); weight_dim[1] = (npy_intp)(M_ * K_ * numgroups);
PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2, PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2,
weight_dim, PyArray_TYPE(weight), 0); weight_dim, PyArray_TYPE(weight), 0);
...@@ -326,9 +333,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -326,9 +333,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im2col // First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight, im2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW, nChannels, bottomHeight,bottomWidth, kH, kW, dilH, dilW, padH, padW, dH, dW,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
for(int g = 0; g < numgroups; ++g){
// Second, gemm // Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
...@@ -336,12 +344,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -336,12 +344,13 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
%(gemm)s(&Trans, &NTrans, %(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_, &K_, &M_, &N_,
&one, &one,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_, (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_, (%(float_type)s*)PyArray_DATA(top) + g * group_top_stride + n * batch_top_stride, &N_,
(n == 0) ? &zero : &one, (n == 0) ? &zero : &one,
(%(float_type)s*)PyArray_DATA(local_weight) + (%(float_type)s*)PyArray_DATA(local_weight) + g * group_weight_stride +
tid * weight_dim[1], &K_); tid * weight_dim[1], &K_);
} }
}
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
...@@ -401,19 +410,21 @@ PyArrayObject* corrMM(PyArrayObject* bottom, ...@@ -401,19 +410,21 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
%(blas_set_num_threads)s(1); %(blas_set_num_threads)s(1);
%(omp_flags)s %(omp_flags)s
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
// gemm into columns
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
for ( int g = 0;g < numgroups; ++g){
// gemm into columns
%(gemm)s(&NTrans, &Trans, %(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_, &N_, &K_, &M_,
&one, &one,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_, (%(float_type)s*)PyArray_DATA(top) + g * group_top_stride + n * batch_top_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_, (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero, &zero,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_); (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_);
}
// col2im back to the data // col2im back to the data
col2im((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, bottomHeight, bottomWidth, col2im((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, bottomHeight, bottomWidth,
kH, kW, dilH, dilW, padH, padW, kH, kW, dilH, dilW, padH, padW,
dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride); dH, dW, (%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride);
} }
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
......
...@@ -88,7 +88,9 @@ def local_abstractconv_gemm(node): ...@@ -88,7 +88,9 @@ def local_abstractconv_gemm(node):
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
rval = CorrMM(border_mode=node.op.border_mode, rval = CorrMM(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, kern) filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, kern)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
return [rval] return [rval]
...@@ -133,7 +135,8 @@ def local_abstractconv_gradweight_gemm(node): ...@@ -133,7 +135,8 @@ def local_abstractconv_gradweight_gemm(node):
rval = CorrMM_gradWeights(border_mode=node.op.border_mode, rval = CorrMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, topgrad, shape) filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, topgrad, shape)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
# need to flip the kernel if necessary # need to flip the kernel if necessary
...@@ -190,7 +193,8 @@ def local_abstractconv_gradinputs_gemm(node): ...@@ -190,7 +193,8 @@ def local_abstractconv_gradinputs_gemm(node):
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
rval = CorrMM_gradInputs(border_mode=node.op.border_mode, rval = CorrMM_gradInputs(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(kern, topgrad, filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(kern, topgrad,
shape) shape)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
...@@ -238,6 +242,8 @@ def local_conv2d_cpu(node): ...@@ -238,6 +242,8 @@ def local_conv2d_cpu(node):
if not node.op.filter_flip: if not node.op.filter_flip:
# Not tested yet # Not tested yet
return None return None
if node.op.num_groups > 1:
return None
rval = conv2d(img, kern, rval = conv2d(img, kern,
node.op.imshp, node.op.kshp, node.op.imshp, node.op.kshp,
...@@ -295,6 +301,8 @@ def local_conv2d_gradweight_cpu(node): ...@@ -295,6 +301,8 @@ def local_conv2d_gradweight_cpu(node):
if not node.op.filter_flip: if not node.op.filter_flip:
# Not tested yet # Not tested yet
return return
if node.op.num_groups > 1:
return None
if node.op.border_mode == 'valid' and \ if node.op.border_mode == 'valid' and \
(node.op.subsample != (1, 1)): (node.op.subsample != (1, 1)):
...@@ -447,6 +455,8 @@ def local_conv2d_gradinputs_cpu(node): ...@@ -447,6 +455,8 @@ def local_conv2d_gradinputs_cpu(node):
if not node.op.filter_flip: if not node.op.filter_flip:
# Not tested yet # Not tested yet
return None return None
if node.op.num_groups > 1:
return None
# Conv 3d implementation, needed when subsample > 2 # Conv 3d implementation, needed when subsample > 2
if node.op.border_mode == 'valid' and node.op.subsample != (1, 1): if node.op.border_mode == 'valid' and node.op.subsample != (1, 1):
......
...@@ -1699,3 +1699,158 @@ class TestConv2dGrads(unittest.TestCase): ...@@ -1699,3 +1699,158 @@ class TestConv2dGrads(unittest.TestCase):
) )
f_new = theano.function([self.x, self.output_grad_wrt], conv_wrt_w_out) f_new = theano.function([self.x, self.output_grad_wrt], conv_wrt_w_out)
utt.assert_allclose(f_new(input_val, out_grad_val), f_old(input_val, filter_val, out_grad_val)) utt.assert_allclose(f_new(input_val, out_grad_val), f_old(input_val, filter_val, out_grad_val))
class Grouped_conv_noOptim(unittest.TestCase):
conv2d = theano.tensor.nnet.abstract_conv.AbstractConv2d
conv2d_gradw = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights
conv2d_gradi = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs
conv2d_op = theano.tensor.nnet.abstract_conv.AbstractConv2d
conv2d_gradw_op = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights
conv2d_gradi_op = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs
mode = theano.Mode(optimizer=None)
flip_filter = False
is_dnn = False
def setUp(self):
self.num_groups = [3, 2, 4, 4]
self.border_mode = 'valid'
self.subsample = (1, 1)
self.img_shape = [(5, 6, 5, 5), (4, 4, 7, 5), (3, 8, 5, 3), (2, 4, 7, 7)]
self.kern_shape = [(6, 2, 3, 3), (6, 2, 5, 3), (4, 2, 3, 3), (4, 1, 3, 5)]
self.top_shape = [(5, 6, 3, 3), (4, 6, 3, 3), (3, 4, 3, 1), (2, 4, 5, 3)]
self.filter_dilation = (1, 1)
self.ref_mode = 'FAST_RUN'
if theano.config.cxx == "":
raise SkipTest("CorrMM needs cxx")
def test_fwd(self):
img_sym = theano.tensor.tensor4('img')
kern_sym = theano.tensor.tensor4('kern')
for imshp, kshp, groups in zip(self.img_shape, self.kern_shape, self.num_groups):
img = np.random.random(imshp).astype(theano.config.floatX)
kern = np.random.random(kshp).astype(theano.config.floatX)
split_imgs = np.split(img, groups, axis=1)
split_kern = np.split(kern, groups, axis=0)
grouped_conv_op = self.conv2d(border_mode=self.border_mode,
subsample=self.subsample,
filter_dilation=self.filter_dilation,
num_groups=groups)
if self.flip_filter:
grouped_conv_output = grouped_conv_op(img_sym, kern_sym[:, :, ::-1, ::-1])
else:
grouped_conv_output = grouped_conv_op(img_sym, kern_sym)
grouped_func = theano.function([img_sym, kern_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_op)
for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, kern)
ref_conv_op = conv2d_corr(img_sym,
kern_sym,
border_mode=self.border_mode,
subsample=self.subsample,
filter_dilation=self.filter_dilation)
ref_func = theano.function([img_sym, kern_sym], ref_conv_op,
mode=self.ref_mode)
ref_concat_output = [ref_func(img_arr, kern_arr)
for img_arr, kern_arr in zip(split_imgs, split_kern)]
ref_concat_output = np.concatenate(ref_concat_output, axis=1)
utt.assert_allclose(grouped_output, ref_concat_output)
utt.verify_grad(grouped_conv_op,
[img, kern],
mode=self.mode,
eps=1)
def test_gradweights(self):
img_sym = theano.tensor.tensor4('img')
top_sym = theano.tensor.tensor4('top')
for imshp, kshp, tshp, groups in zip(self.img_shape, self.kern_shape, self.top_shape, self.num_groups):
img = np.random.random(imshp).astype(theano.config.floatX)
top = np.random.random(tshp).astype(theano.config.floatX)
split_imgs = np.split(img, groups, axis=1)
split_top = np.split(top, groups, axis=1)
grouped_convgrad_op = self.conv2d_gradw(border_mode=self.border_mode,
subsample=self.subsample,
filter_dilation=self.filter_dilation,
num_groups=groups)
grouped_conv_output = grouped_convgrad_op(img_sym,
top_sym,
tensor.as_tensor_variable(kshp if self.is_dnn else kshp[-2:]))
if self.flip_filter:
grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1]
grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_gradw_op)
for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, top)
ref_conv_op = conv2d_corr_gw(img_sym,
top_sym,
kshp,
border_mode=self.border_mode,
subsample=self.subsample,
filter_dilation=self.filter_dilation)
ref_func = theano.function([img_sym, top_sym], ref_conv_op,
mode=self.ref_mode)
ref_concat_output = [ref_func(img_arr, top_arr)
for img_arr, top_arr in zip(split_imgs, split_top)]
ref_concat_output = np.concatenate(ref_concat_output, axis=0)
utt.assert_allclose(grouped_output, ref_concat_output)
def conv_gradweight(inputs_val, output_val):
return grouped_convgrad_op(inputs_val, output_val,
tensor.as_tensor_variable(kshp if self.is_dnn else kshp[-2:]))
utt.verify_grad(conv_gradweight,
[img, top],
mode=self.mode, eps=1)
def test_gradinputs(self):
kern_sym = theano.tensor.tensor4('kern')
top_sym = theano.tensor.tensor4('top')
for imshp, kshp, tshp, groups in zip(self.img_shape, self.kern_shape, self.top_shape, self.num_groups):
kern = np.random.random(kshp).astype(theano.config.floatX)
top = np.random.random(tshp).astype(theano.config.floatX)
split_kerns = np.split(kern, groups, axis=0)
split_top = np.split(top, groups, axis=1)
grouped_convgrad_op = self.conv2d_gradi(border_mode=self.border_mode,
subsample=self.subsample,
filter_dilation=self.filter_dilation,
num_groups=groups)
if self.flip_filter:
grouped_conv_output = grouped_convgrad_op(kern_sym[:, :, ::-1, ::-1], top_sym, tensor.as_tensor_variable(imshp[-2:]))
else:
grouped_conv_output = grouped_convgrad_op(kern_sym,
top_sym,
tensor.as_tensor_variable(imshp if self.is_dnn else imshp[-2:]))
grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_gradi_op)
for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(kern, top)
ref_conv_op = conv2d_corr_gi(kern_sym,
top_sym,
imshp,
border_mode=self.border_mode,
subsample=self.subsample,
filter_dilation=self.filter_dilation)
ref_func = theano.function([kern_sym, top_sym], ref_conv_op,
mode=self.ref_mode)
ref_concat_output = [ref_func(kern_arr, top_arr)
for kern_arr, top_arr in zip(split_kerns, split_top)]
ref_concat_output = np.concatenate(ref_concat_output, axis=1)
utt.assert_allclose(grouped_output, ref_concat_output)
def conv_gradinputs(filters_val, output_val):
return grouped_convgrad_op(filters_val, output_val,
tensor.as_tensor_variable(imshp if self.is_dnn else imshp[-2:]))
utt.verify_grad(conv_gradinputs,
[kern, top],
mode=self.mode, eps=1)
...@@ -10,6 +10,7 @@ import theano ...@@ -10,6 +10,7 @@ import theano
import theano.tensor as T import theano.tensor as T
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr, conv from theano.tensor.nnet import corr, conv
from theano.tensor.nnet.tests.test_abstract_conv import Grouped_conv_noOptim
class TestCorr2D(utt.InferShapeTester): class TestCorr2D(utt.InferShapeTester):
...@@ -416,6 +417,49 @@ class TestCorr2D(utt.InferShapeTester): ...@@ -416,6 +417,49 @@ class TestCorr2D(utt.InferShapeTester):
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 2, non_contiguous=True) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 2, non_contiguous=True)
class TestGroupCorr2d(Grouped_conv_noOptim):
if theano.config.mode == "FAST_COMPILE":
mode = theano.compile.get_mode("FAST_RUN")
else:
mode = None
conv2d = corr.CorrMM
conv2d_gradw = corr.CorrMM_gradWeights
conv2d_gradi = corr.CorrMM_gradInputs
conv2d_op = corr.CorrMM
conv2d_gradw_op = corr.CorrMM_gradWeights
conv2d_gradi_op = corr.CorrMM_gradInputs
flip_filter = True
is_dnn = False
def test_graph(self):
# define common values first
groups = 3
bottom = np.random.rand(3, 6, 5, 5).astype(theano.config.floatX)
kern = np.random.rand(9, 2, 3, 3).astype(theano.config.floatX)
bottom_sym = T.tensor4('bottom')
kern_sym = T.tensor4('kern')
# grouped convolution graph
conv_group = self.conv2d(num_groups=groups)(bottom_sym, kern_sym)
gconv_func = theano.function([bottom_sym, kern_sym], conv_group, mode=self.mode)
# Graph for the normal hard way
kern_offset = kern_sym.shape[0] // groups
bottom_offset = bottom_sym.shape[1] // groups
split_conv_output = [self.conv2d()(bottom_sym[:, i * bottom_offset:(i + 1) * bottom_offset, :, :],
kern_sym[i * kern_offset:(i + 1) * kern_offset, :, :, :])
for i in range(groups)]
concatenated_output = T.concatenate(split_conv_output, axis=1)
conv_func = theano.function([bottom_sym, kern_sym], concatenated_output, mode=self.mode)
# calculate outputs for each graph
gconv_output = gconv_func(bottom, kern)
conv_output = conv_func(bottom, kern)
# compare values
utt.assert_allclose(gconv_output, conv_output)
if __name__ == '__main__': if __name__ == '__main__':
t = TestCorr2D('setUp') t = TestCorr2D('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论