提交 4747cf44 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6267 from affanv14/g3

3D Grouped Convolutions
...@@ -516,13 +516,13 @@ class BaseGpuCorrMM(CGpuKernelBase): ...@@ -516,13 +516,13 @@ class BaseGpuCorrMM(CGpuKernelBase):
# flops for any direction, sampling, padding, and border mode # flops for any direction, sampling, padding, and border mode
inputs, filters = inp inputs, filters = inp
outputs, = outp outputs, = outp
assert inputs[1] == filters[1] assert inputs[1] == (filters[1] * self.num_groups)
# nb mul and add by output pixel # nb mul and add by output pixel
flops = filters[2] * filters[3] * 2 flops = filters[2] * filters[3] * 2
# nb flops by output image # nb flops by output image
flops *= outputs[2] * outputs[3] flops *= outputs[2] * outputs[3]
# nb patch multiplied # nb patch multiplied
flops *= inputs[1] * filters[0] * inputs[0] flops *= inputs[1] * filters[0] * inputs[0] / self.num_groups
return flops return flops
def c_headers(self): def c_headers(self):
...@@ -1067,14 +1067,17 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1067,14 +1067,17 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
Perform subsampling of the output (default: (1, 1, 1)). Perform subsampling of the output (default: (1, 1, 1)).
filter_dilation filter_dilation
Perform subsampling of the input, also known as dilation (default: (1, 1, 1)). Perform subsampling of the input, also known as dilation (default: (1, 1, 1)).
num_groups :
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out convolutions separately (default: 1).
""" """
check_broadcast = False check_broadcast = False
__props__ = ('border_mode', 'subsample', 'filter_dilation') __props__ = ('border_mode', 'subsample', 'filter_dilation', 'num_groups')
_f16_ok = True _f16_ok = True
def __init__(self, border_mode="valid", subsample=(1, 1, 1), def __init__(self, border_mode="valid", subsample=(1, 1, 1),
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1), num_groups=1):
if isinstance(border_mode, integer_types): if isinstance(border_mode, integer_types):
border_mode = (border_mode, border_mode, border_mode) border_mode = (border_mode, border_mode, border_mode)
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
...@@ -1093,6 +1096,9 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1093,6 +1096,9 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
raise ValueError("filter_dilation must have three elements") raise ValueError("filter_dilation must have three elements")
self.subsample = tuple(subsample) self.subsample = tuple(subsample)
self.filter_dilation = tuple(filter_dilation) self.filter_dilation = tuple(filter_dilation)
if num_groups < 1:
raise ValueError("Number of groups should be greater than 0")
self.num_groups = num_groups
CGpuKernelBase.__init__(self, ['c_code/corr3d_gemm.c']) CGpuKernelBase.__init__(self, ['c_code/corr3d_gemm.c'])
@property @property
...@@ -1102,11 +1108,17 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1102,11 +1108,17 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
return (0, 0, 0) return (0, 0, 0)
def __str__(self): def __str__(self):
return '%s{%s, %s, %s}' % ( return '%s{%s, %s, %s, %s}' % (
self.__class__.__name__, self.__class__.__name__,
self.border_mode, self.border_mode,
str(self.subsample), str(self.subsample),
str(self.filter_dilation)) str(self.filter_dilation),
str(self.num_groups))
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'num_groups'):
self.num_groups = 1
def flops(self, inp, outp): def flops(self, inp, outp):
""" """
...@@ -1117,13 +1129,13 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1117,13 +1129,13 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
# flops for any direction, sampling, padding, and border mode # flops for any direction, sampling, padding, and border mode
inputs, filters = inp inputs, filters = inp
outputs, = outp outputs, = outp
assert inputs[1] == filters[1] assert inputs[1] == (filters[1] * self.num_groups)
# nb mul and add by output pixel # nb mul and add by output pixel
flops = filters[2] * filters[3] * filters[4] * 2 flops = filters[2] * filters[3] * filters[4] * 2
# nb flops by output image # nb flops by output image
flops *= outputs[2] * outputs[3] * outputs[4] flops *= outputs[2] * outputs[3] * outputs[4]
# nb patch multiplied # nb patch multiplied
flops *= inputs[1] * filters[0] * inputs[0] flops *= inputs[1] * filters[0] * inputs[0] / self.num_groups
return flops return flops
def c_headers(self): def c_headers(self):
...@@ -1189,6 +1201,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1189,6 +1201,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
""" """
dH, dW, dD = self.subsample dH, dW, dD = self.subsample
dilH, dilW, dilD = self.filter_dilation dilH, dilW, dilD = self.filter_dilation
numgroups = self.num_groups
if self.border_mode == "half": if self.border_mode == "half":
padH = padW = padD = -1 padH = padW = padD = -1
elif self.border_mode == "full": elif self.border_mode == "full":
...@@ -1249,6 +1262,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1249,6 +1262,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
int padH = %(padH)s; int padH = %(padH)s;
int padW = %(padW)s; int padW = %(padW)s;
int padD = %(padD)s; int padD = %(padD)s;
int numgroups = %(numgroups)s;
PyGpuArrayObject * bottom = %(bottom)s; PyGpuArrayObject * bottom = %(bottom)s;
PyGpuArrayObject * weights = %(weights)s; PyGpuArrayObject * weights = %(weights)s;
...@@ -1372,7 +1386,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1372,7 +1386,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
// output is weights: (num_filters, num_channels, height, width, depth) // output is weights: (num_filters, num_channels, height, width, depth)
// height, width and depth: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1 // height, width and depth: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
out_dim[0] = PyGpuArray_DIMS(top)[1]; out_dim[0] = PyGpuArray_DIMS(top)[1];
out_dim[1] = PyGpuArray_DIMS(bottom)[1]; out_dim[1] = PyGpuArray_DIMS(bottom)[1] / numgroups;
out_dim[2] = kH; // already inferred further above out_dim[2] = kH; // already inferred further above
out_dim[3] = kW; // how convenient out_dim[3] = kW; // how convenient
out_dim[4] = kD; out_dim[4] = kD;
...@@ -1399,7 +1413,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1399,7 +1413,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
// output is bottom: (batchsize, num_channels, height, width, depth) // output is bottom: (batchsize, num_channels, height, width, depth)
// height, width and depth: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad // height, width and depth: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
out_dim[0] = PyGpuArray_DIMS(top)[0]; out_dim[0] = PyGpuArray_DIMS(top)[0];
out_dim[1] = PyGpuArray_DIMS(weights)[1]; out_dim[1] = PyGpuArray_DIMS(weights)[1] * numgroups;
out_dim[2] = (%(height)s != -1) ? %(height)s : (PyGpuArray_DIMS(top)[2] - 1) * dH + (PyGpuArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH; out_dim[2] = (%(height)s != -1) ? %(height)s : (PyGpuArray_DIMS(top)[2] - 1) * dH + (PyGpuArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH;
out_dim[3] = (%(width)s != -1) ? %(width)s : (PyGpuArray_DIMS(top)[3] - 1) * dW + (PyGpuArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW; out_dim[3] = (%(width)s != -1) ? %(width)s : (PyGpuArray_DIMS(top)[3] - 1) * dW + (PyGpuArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW;
out_dim[4] = (%(depth)s != -1) ? %(depth)s : (PyGpuArray_DIMS(top)[4] - 1) * dD + (PyGpuArray_DIMS(weights)[4]-1)*dilD + 1 - 2*padD; out_dim[4] = (%(depth)s != -1) ? %(depth)s : (PyGpuArray_DIMS(top)[4] - 1) * dD + (PyGpuArray_DIMS(weights)[4]-1)*dilD + 1 - 2*padD;
...@@ -1448,7 +1462,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase): ...@@ -1448,7 +1462,7 @@ class BaseGpuCorr3dMM(CGpuKernelBase):
// Call GPU code // Call GPU code
out2 = corr3dMM(%(bottom)s, %(weights)s, %(top)s, direction, out2 = corr3dMM(%(bottom)s, %(weights)s, %(top)s, direction,
dH, dW, dD, dilH, dilW, dilD, padH, padW, padD); dH, dW, dD, dilH, dilW, dilD, padH, padW, padD, numgroups);
if (out2==NULL){ if (out2==NULL){
%(fail)s %(fail)s
} }
...@@ -1484,6 +1498,11 @@ class GpuCorr3dMM(BaseGpuCorr3dMM): ...@@ -1484,6 +1498,11 @@ class GpuCorr3dMM(BaseGpuCorr3dMM):
The filter dilation operation applied to each input image. The filter dilation operation applied to each input image.
Should be a tuple with 3 elements. Should be a tuple with 3 elements.
Set to `(1, 1, 1)` to disable filter dilation. Set to `(1, 1, 1)` to disable filter dilation.
num_groups
The number of distinct groups the image and kernel must be
divided into.
Should be an int.
Set to 1 to disable grouped convolution.
Notes Notes
----- -----
...@@ -1503,9 +1522,10 @@ class GpuCorr3dMM(BaseGpuCorr3dMM): ...@@ -1503,9 +1522,10 @@ class GpuCorr3dMM(BaseGpuCorr3dMM):
""" """
def __init__(self, border_mode="valid", def __init__(self, border_mode="valid",
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
super(GpuCorr3dMM, self).__init__(border_mode, subsample, super(GpuCorr3dMM, self).__init__(border_mode, subsample,
filter_dilation) filter_dilation, num_groups)
def make_node(self, img, kern): def make_node(self, img, kern):
ctx_name = infer_context_name(img, kern) ctx_name = infer_context_name(img, kern)
...@@ -1534,11 +1554,13 @@ class GpuCorr3dMM(BaseGpuCorr3dMM): ...@@ -1534,11 +1554,13 @@ class GpuCorr3dMM(BaseGpuCorr3dMM):
top = gpu_contiguous(top) top = gpu_contiguous(top)
d_bottom = GpuCorr3dMM_gradInputs(self.border_mode, d_bottom = GpuCorr3dMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)( self.filter_dilation,
self.num_groups)(
weights, top, bottom.shape[-3:]) weights, top, bottom.shape[-3:])
d_weights = GpuCorr3dMM_gradWeights(self.border_mode, d_weights = GpuCorr3dMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)( self.filter_dilation,
self.num_groups)(
bottom, top, weights.shape[-3:]) bottom, top, weights.shape[-3:])
return d_bottom, d_weights return d_bottom, d_weights
...@@ -1556,10 +1578,12 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM): ...@@ -1556,10 +1578,12 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM):
def __init__(self, border_mode="valid", def __init__(self, border_mode="valid",
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
super(GpuCorr3dMM_gradWeights, self).__init__(border_mode, super(GpuCorr3dMM_gradWeights, self).__init__(border_mode,
subsample, subsample,
filter_dilation) filter_dilation,
num_groups)
def make_node(self, img, topgrad, shape=None): def make_node(self, img, topgrad, shape=None):
ctx_name = infer_context_name(img, topgrad) ctx_name = infer_context_name(img, topgrad)
...@@ -1600,11 +1624,13 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM): ...@@ -1600,11 +1624,13 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM):
weights = gpu_contiguous(weights) weights = gpu_contiguous(weights)
d_bottom = GpuCorr3dMM_gradInputs(self.border_mode, d_bottom = GpuCorr3dMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, self.filter_dilation,
self.num_groups)(weights,
top, top,
bottom.shape[-3:]) bottom.shape[-3:])
d_top = GpuCorr3dMM( d_top = GpuCorr3dMM(
self.border_mode, self.subsample, self.filter_dilation)(bottom, weights) self.border_mode, self.subsample, self.filter_dilation,
self.num_groups)(bottom, weights)
d_height_width_depth = (theano.gradient.DisconnectedType()(),)\ d_height_width_depth = (theano.gradient.DisconnectedType()(),)\
* 3 if len(inp) == 5 else () * 3 if len(inp) == 5 else ()
return (d_bottom, d_top) + d_height_width_depth return (d_bottom, d_top) + d_height_width_depth
...@@ -1629,9 +1655,10 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM): ...@@ -1629,9 +1655,10 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
def __init__(self, border_mode="valid", def __init__(self, border_mode="valid",
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
super(GpuCorr3dMM_gradInputs, self).__init__(border_mode, subsample, super(GpuCorr3dMM_gradInputs, self).__init__(border_mode, subsample,
filter_dilation) filter_dilation, num_groups)
def make_node(self, kern, topgrad, shape=None): def make_node(self, kern, topgrad, shape=None):
ctx_name = infer_context_name(kern, topgrad) ctx_name = infer_context_name(kern, topgrad)
...@@ -1651,6 +1678,10 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM): ...@@ -1651,6 +1678,10 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
assert shape[1].ndim == 0 assert shape[1].ndim == 0
assert shape[2].ndim == 0 assert shape[2].ndim == 0
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0], False,
False, False, False]
else:
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1], broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False, False] False, False, False]
return Apply(self, [kern, topgrad] + height_width_depth, return Apply(self, [kern, topgrad] + height_width_depth,
...@@ -1671,12 +1702,14 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM): ...@@ -1671,12 +1702,14 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
bottom = gpu_contiguous(bottom) bottom = gpu_contiguous(bottom)
d_weights = GpuCorr3dMM_gradWeights(self.border_mode, d_weights = GpuCorr3dMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, self.filter_dilation,
self.num_groups)(bottom,
top, top,
weights.shape[-3:]) weights.shape[-3:])
d_top = GpuCorr3dMM(self.border_mode, d_top = GpuCorr3dMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
d_height_width_depth = (theano.gradient.DisconnectedType()(),)\ d_height_width_depth = (theano.gradient.DisconnectedType()(),)\
* 3 if len(inp) == 5 else () * 3 if len(inp) == 5 else ()
return (d_weights, d_top) + d_height_width_depth return (d_weights, d_top) + d_height_width_depth
......
...@@ -411,7 +411,8 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -411,7 +411,8 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
const size_t dilD = 1, const size_t dilD = 1,
const size_t padH = 0, const size_t padH = 0,
const size_t padW = 0, const size_t padW = 0,
const size_t padD = 0) const size_t padD = 0,
const size_t numgroups = 1)
{ {
if (PyGpuArray_NDIM(bottom) != 5) if (PyGpuArray_NDIM(bottom) != 5)
{ {
...@@ -479,11 +480,16 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -479,11 +480,16 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
const size_t kH = PyGpuArray_DIMS(weight)[2]; const size_t kH = PyGpuArray_DIMS(weight)[2];
const size_t kW = PyGpuArray_DIMS(weight)[3]; const size_t kW = PyGpuArray_DIMS(weight)[3];
const size_t kD = PyGpuArray_DIMS(weight)[4]; const size_t kD = PyGpuArray_DIMS(weight)[4];
if (nChannels != PyGpuArray_DIMS(weight)[1]) { if (nChannels != PyGpuArray_DIMS(weight)[1] * numgroups) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"GpuCorr3dMM images and kernel must have the same stack size\n"); "GpuCorr3dMM images and kernel must have the same stack size\n");
return NULL; return NULL;
} }
if ((nFilters % numgroups) != 0) {
PyErr_SetString(PyExc_ValueError,
"CorrMM the number of filters must be divisible by the number of groups\n");
return NULL;
}
// implicit dilated filter // implicit dilated filter
const size_t dil_kH = (kH - 1) * dilH + 1; const size_t dil_kH = (kH - 1) * dilH + 1;
const size_t dil_kW = (kW - 1) * dilW + 1; const size_t dil_kW = (kW - 1) * dilW + 1;
...@@ -511,7 +517,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -511,7 +517,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
" weight shape: %ld %ld %ld %ld %ld\n" " weight shape: %ld %ld %ld %ld %ld\n"
" top shape: %ld %ld %ld %ld %ld (expected %ld %ld %ld %ld %ld)\n", " top shape: %ld %ld %ld %ld %ld (expected %ld %ld %ld %ld %ld)\n",
batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth, batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
nFilters, nChannels, kH, kW, kD, nFilters, nChannels / numgroups, kH, kW, kD,
PyGpuArray_DIMS(top)[0], PyGpuArray_DIMS(top)[1], PyGpuArray_DIMS(top)[0], PyGpuArray_DIMS(top)[1],
PyGpuArray_DIMS(top)[2], PyGpuArray_DIMS(top)[3], PyGpuArray_DIMS(top)[4], PyGpuArray_DIMS(top)[2], PyGpuArray_DIMS(top)[3], PyGpuArray_DIMS(top)[4],
batchSize, nFilters, topHeight, topWidth, topDepth); batchSize, nFilters, topHeight, topWidth, topDepth);
...@@ -542,11 +548,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -542,11 +548,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
} }
// Define some useful variables // Define some useful variables
const size_t bottom_stride = PyGpuArray_STRIDES(bottom)[0] / gpuarray_get_elsize(bottom->ga.typecode); const size_t batch_bottom_stride = PyGpuArray_STRIDES(bottom)[0] / gpuarray_get_elsize(bottom->ga.typecode);
const size_t top_stride = PyGpuArray_STRIDES(top)[0] / gpuarray_get_elsize(top->ga.typecode); const size_t batch_top_stride = PyGpuArray_STRIDES(top)[0] / gpuarray_get_elsize(top->ga.typecode);
const size_t K_ = col_dim[0]; const size_t group_bottom_stride = (PyGpuArray_STRIDES(bottom)[1] * nChannels / numgroups) / gpuarray_get_elsize(bottom->ga.typecode);
const size_t group_top_stride = (PyGpuArray_STRIDES(top)[1] * nFilters / numgroups) / gpuarray_get_elsize(top->ga.typecode);
const size_t group_weight_stride = (PyGpuArray_STRIDES(weight)[0] * nFilters / numgroups) / gpuarray_get_elsize(weight->ga.typecode);
const size_t K_ = col_dim[0] / numgroups;
const size_t N_ = col_dim[1]; const size_t N_ = col_dim[1];
const size_t M_ = nFilters; const size_t group_col_stride = (K_ * N_);
const size_t M_ = nFilters / numgroups;
PyGpuArrayObject *output; PyGpuArrayObject *output;
if (direction == 0) { // forward pass if (direction == 0) { // forward pass
...@@ -567,20 +579,22 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -567,20 +579,22 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im3d2col // First, im3d2col
err = im3d2col( err = im3d2col(
&bottom->ga, n * bottom_stride, nChannels, bottomHeight, &bottom->ga, n * batch_bottom_stride, nChannels, bottomHeight,
bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD, bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
padH, padW, padD, dH, dW, dD, &col->ga); padH, padW, padD, dH, dW, dD, &col->ga);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
} }
for ( size_t g = 0; g < numgroups; ++g){
// Second, gemm // Second, gemm
err = rgemm(cb_fortran, cb_no_trans, cb_no_trans, err = rgemm(cb_fortran, cb_no_trans, cb_no_trans,
N_, M_, K_, 1, N_, M_, K_, 1,
&col->ga, 0, N_, &col->ga, g * group_col_stride, N_,
&weight->ga, 0, K_, &weight->ga, g * group_weight_stride, K_,
0, 0,
&top->ga, n * top_stride, N_); &top->ga, n * batch_top_stride + g * group_top_stride, N_);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM forward encountered an error running gemm."); "GpuCorr3dMM forward encountered an error running gemm.");
...@@ -607,7 +621,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -607,7 +621,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// First, im3d2col // First, im3d2col
err = im3d2col( err = im3d2col(
&bottom->ga, n * bottom_stride, nChannels, bottomHeight, &bottom->ga, n * batch_bottom_stride, nChannels, bottomHeight,
bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD, bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
padH, padW, padD, dH, dW, dD, &col->ga); padH, padW, padD, dH, dW, dD, &col->ga);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
...@@ -618,12 +632,14 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -618,12 +632,14 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.) // is faster than setting weight to all zeros before the loop.)
for ( size_t g = 0; g < numgroups; ++g){
err = rgemm(cb_fortran, cb_trans, cb_no_trans, err = rgemm(cb_fortran, cb_trans, cb_no_trans,
K_, M_, N_, 1, K_, M_, N_, 1,
&col->ga, 0, N_, &col->ga, g * group_col_stride, N_,
&top->ga, n * top_stride, N_, &top->ga, n * batch_top_stride + g * group_top_stride, N_,
(n == 0) ? 0 : 1, (n == 0) ? 0 : 1,
&weight->ga, 0, K_); &weight->ga, g * group_weight_stride, K_);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad weights encountered an error running gemm."); "GpuCorr3dMM grad weights encountered an error running gemm.");
...@@ -658,12 +674,14 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -658,12 +674,14 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
// Iterate over batch // Iterate over batch
for (size_t n = 0; n < batchSize; n++) { for (size_t n = 0; n < batchSize; n++) {
// gemm into columns // gemm into columns
for ( size_t g = 0; g < numgroups; ++g){
err = rgemm(cb_fortran, cb_no_trans, cb_trans, err = rgemm(cb_fortran, cb_no_trans, cb_trans,
N_, K_, M_, 1, N_, K_, M_, 1,
&top->ga, n * top_stride, N_, &top->ga, n * batch_top_stride + g * group_top_stride, N_,
&weight->ga, 0, K_, &weight->ga, g * group_weight_stride, K_,
0, 0,
&col->ga, 0, N_); &col->ga, g * group_col_stride, N_);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad inputs encountered an error running gemm."); "GpuCorr3dMM grad inputs encountered an error running gemm.");
...@@ -674,7 +692,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom, ...@@ -674,7 +692,7 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
err = col2im3d(&col->ga, nChannels, err = col2im3d(&col->ga, nChannels,
bottomHeight, bottomWidth, bottomDepth, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, kH, kW, kD, dilH, dilW, dilD, padH, padW, padD,
dH, dW, dD, &bottom->ga, n * bottom_stride); dH, dW, dD, &bottom->ga, n * batch_bottom_stride);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
Py_DECREF(col); Py_DECREF(col);
return NULL; return NULL;
......
...@@ -2790,6 +2790,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs): ...@@ -2790,6 +2790,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
if version(raises=False) < 6000 and op.filter_dilation != (1, 1): if version(raises=False) < 6000 and op.filter_dilation != (1, 1):
return None return None
if op.num_groups > 1:
return None
inp1 = inputs[0] inp1 = inputs[0]
inp2 = inputs[1] inp2 = inputs[1]
...@@ -2839,6 +2841,8 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs): ...@@ -2839,6 +2841,8 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs):
if version(raises=False) < 6000 and op.filter_dilation != (1, 1, 1): if version(raises=False) < 6000 and op.filter_dilation != (1, 1, 1):
return None return None
if op.num_groups > 1:
return None
inp1 = inputs[0] inp1 = inputs[0]
inp2 = inputs[1] inp2 = inputs[1]
......
...@@ -1707,7 +1707,8 @@ def local_abstractconv3d_gemm(node): ...@@ -1707,7 +1707,8 @@ def local_abstractconv3d_gemm(node):
border_mode = node.op.border_mode border_mode = node.op.border_mode
subsample = node.op.subsample subsample = node.op.subsample
filter_dilation = node.op.filter_dilation filter_dilation = node.op.filter_dilation
if ((border_mode == 'full') and (subsample == (1, 1, 1))): num_groups = node.op.num_groups
if ((border_mode == 'full') and (subsample == (1, 1, 1)) and num_groups == 1):
if not node.op.filter_flip: if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1, ::-1]
# need to dimshuffle the kernel for full convolution # need to dimshuffle the kernel for full convolution
...@@ -1724,7 +1725,8 @@ def local_abstractconv3d_gemm(node): ...@@ -1724,7 +1725,8 @@ def local_abstractconv3d_gemm(node):
# By default use GpuCorr3dMM # By default use GpuCorr3dMM
rval = GpuCorr3dMM(border_mode, rval = GpuCorr3dMM(border_mode,
subsample, subsample,
filter_dilation)(gpu_contiguous(img), filter_dilation,
num_groups)(gpu_contiguous(img),
gpu_contiguous(kern)) gpu_contiguous(kern))
# call GpuCorr3dMM_gradWeights if good # call GpuCorr3dMM_gradWeights if good
...@@ -1737,7 +1739,8 @@ def local_abstractconv3d_gemm(node): ...@@ -1737,7 +1739,8 @@ def local_abstractconv3d_gemm(node):
(None not in node.op.imshp[-3:]) and (None not in node.op.imshp[-3:]) and
(node.op.kshp is not None) and (node.op.kshp is not None) and
(None not in node.op.kshp) and (None not in node.op.kshp) and
border_mode != "half"): border_mode != "half" and
num_groups == 1):
# we know the kernel and output size # we know the kernel and output size
prod1 = node.op.kshp[0] * node.op.kshp[1] * node.op.kshp[2] prod1 = node.op.kshp[0] * node.op.kshp[1] * node.op.kshp[2]
prod2 = ((node.op.imshp[-3] - node.op.kshp[0] + 1) * prod2 = ((node.op.imshp[-3] - node.op.kshp[0] + 1) *
...@@ -1929,7 +1932,8 @@ def local_abstractconv3d_gradweights_gemm(node): ...@@ -1929,7 +1932,8 @@ def local_abstractconv3d_gradweights_gemm(node):
rval = GpuCorr3dMM_gradWeights(border_mode=node.op.border_mode, rval = GpuCorr3dMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)( filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(
gpu_contiguous(img), gpu_contiguous(topgrad), shape) gpu_contiguous(img), gpu_contiguous(topgrad), shape)
if node.op.filter_flip: if node.op.filter_flip:
rval = rval[:, :, ::-1, ::-1, ::-1] rval = rval[:, :, ::-1, ::-1, ::-1]
...@@ -1999,7 +2003,8 @@ def local_abstractconv3d_gradinputs_gemm(node): ...@@ -1999,7 +2003,8 @@ def local_abstractconv3d_gradinputs_gemm(node):
rval = GpuCorr3dMM_gradInputs(border_mode=node.op.border_mode, rval = GpuCorr3dMM_gradInputs(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)( filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(
gpu_contiguous(kern), gpu_contiguous(topgrad), shape) gpu_contiguous(kern), gpu_contiguous(topgrad), shape)
return [rval] return [rval]
......
...@@ -2292,11 +2292,11 @@ def dconv2di(border_mode, subsample, filter_dilation, num_groups): ...@@ -2292,11 +2292,11 @@ def dconv2di(border_mode, subsample, filter_dilation, num_groups):
class Cudnn_grouped_conv(Grouped_conv_noOptim): class Cudnn_grouped_conv(Grouped_conv_noOptim):
mode = mode_with_gpu mode = mode_with_gpu
conv2d = staticmethod(dconv2d) conv = staticmethod(dconv2d)
conv2d_gradw = staticmethod(dconv2dw) conv_gradw = staticmethod(dconv2dw)
conv2d_gradi = staticmethod(dconv2di) conv_gradi = staticmethod(dconv2di)
conv2d_op = dnn.GpuDnnConv conv_op = dnn.GpuDnnConv
conv2d_gradw_op = dnn.GpuDnnConvGradW conv_gradw_op = dnn.GpuDnnConvGradW
conv2d_gradi_op = dnn.GpuDnnConvGradI conv_gradi_op = dnn.GpuDnnConvGradI
flip_filter = False flip_filter = False
is_dnn = True is_dnn = True
...@@ -224,11 +224,11 @@ class TestCorrMM(unittest.TestCase): ...@@ -224,11 +224,11 @@ class TestCorrMM(unittest.TestCase):
class TestGroupGpuCorr2d(Grouped_conv_noOptim): class TestGroupGpuCorr2d(Grouped_conv_noOptim):
mode = theano.compile.get_mode("FAST_RUN") mode = theano.compile.get_mode("FAST_RUN")
conv2d = GpuCorrMM conv = GpuCorrMM
conv2d_gradw = GpuCorrMM_gradWeights conv_gradw = GpuCorrMM_gradWeights
conv2d_gradi = GpuCorrMM_gradInputs conv_gradi = GpuCorrMM_gradInputs
conv2d_op = GpuCorrMM conv_op = GpuCorrMM
conv2d_gradw_op = GpuCorrMM_gradWeights conv_gradw_op = GpuCorrMM_gradWeights
conv2d_gradi_op = GpuCorrMM_gradInputs conv_gradi_op = GpuCorrMM_gradInputs
flip_filter = True flip_filter = True
is_dnn = False is_dnn = False
...@@ -11,6 +11,7 @@ from theano.tensor.nnet.corr3d import Corr3dMM, Corr3dMM_gradWeights, Corr3dMM_g ...@@ -11,6 +11,7 @@ from theano.tensor.nnet.corr3d import Corr3dMM, Corr3dMM_gradWeights, Corr3dMM_g
from ..type import gpuarray_shared_constructor from ..type import gpuarray_shared_constructor
from ..blas import GpuCorr3dMM, GpuCorr3dMM_gradWeights, GpuCorr3dMM_gradInputs from ..blas import GpuCorr3dMM, GpuCorr3dMM_gradWeights, GpuCorr3dMM_gradInputs
from .config import mode_with_gpu, mode_without_gpu, ref_cast from .config import mode_with_gpu, mode_without_gpu, ref_cast
from theano.tensor.nnet.tests.test_abstract_conv import Grouped_conv3d_noOptim
class TestCorr3dMM(unittest.TestCase): class TestCorr3dMM(unittest.TestCase):
...@@ -218,3 +219,15 @@ class TestCorr3dMM(unittest.TestCase): ...@@ -218,3 +219,15 @@ class TestCorr3dMM(unittest.TestCase):
verify_grad=False) verify_grad=False)
self.run_gradinput(inputs_shape=(1, 1024, 3, 3, 1), self.run_gradinput(inputs_shape=(1, 1024, 3, 3, 1),
filters_shape=(1, 1, 1, 1, 1024)) filters_shape=(1, 1, 1, 1, 1024))
class TestGroupGpuCorr3d(Grouped_conv3d_noOptim):
    # Configures the grouped 3D convolution test case imported from
    # test_abstract_conv with the gpuarray Corr3dMM ops.
    # NOTE(review): how each attribute is consumed is defined by
    # Grouped_conv3d_noOptim, which is not visible here -- confirm there.
    mode = theano.compile.get_mode("FAST_RUN")

    # Each op class is bound under both attribute names (plain and ``_op``).
    conv = conv_op = GpuCorr3dMM
    conv_gradw = conv_gradw_op = GpuCorr3dMM_gradWeights
    conv_gradi = conv_gradi_op = GpuCorr3dMM_gradInputs

    flip_filter = True
    is_dnn = False
...@@ -184,7 +184,6 @@ def get_conv_gradweights_shape(image_shape, top_shape, ...@@ -184,7 +184,6 @@ def get_conv_gradweights_shape(image_shape, top_shape,
if filter_dilation is None: if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int') filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1: if num_groups > 1:
assert len(subsample) == 2
nchan = nchan // num_groups nchan = nchan // num_groups
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
...@@ -295,7 +294,6 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape, ...@@ -295,7 +294,6 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape,
if filter_dilation is None: if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int') filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1: if num_groups > 1:
assert len(subsample) == 2
nkern = nkern * num_groups nkern = nkern * num_groups
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
...@@ -671,7 +669,8 @@ def conv3d(input, ...@@ -671,7 +669,8 @@ def conv3d(input,
border_mode='valid', border_mode='valid',
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
""" """
This function will build the symbolic graph for convolving a mini-batch of a This function will build the symbolic graph for convolving a mini-batch of a
stack of 3D inputs with a set of 3D filters. The implementation is modelled stack of 3D inputs with a set of 3D filters. The implementation is modelled
...@@ -734,6 +733,10 @@ def conv3d(input, ...@@ -734,6 +733,10 @@ def conv3d(input,
Factor by which to subsample (stride) the input. Factor by which to subsample (stride) the input.
Also called dilation elsewhere. Also called dilation elsewhere.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out its convolutions separately.
Returns Returns
------- -------
Symbolic 5D tensor Symbolic 5D tensor
...@@ -759,7 +762,8 @@ def conv3d(input, ...@@ -759,7 +762,8 @@ def conv3d(input,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return conv_op(input, filters) return conv_op(input, filters)
...@@ -910,7 +914,8 @@ def conv3d_grad_wrt_inputs(output_grad, ...@@ -910,7 +914,8 @@ def conv3d_grad_wrt_inputs(output_grad,
border_mode='valid', border_mode='valid',
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its inputs """Compute conv output gradient w.r.t its inputs
This function builds the symbolic graph for getting the This function builds the symbolic graph for getting the
...@@ -983,6 +988,9 @@ def conv3d_grad_wrt_inputs(output_grad, ...@@ -983,6 +988,9 @@ def conv3d_grad_wrt_inputs(output_grad,
filter_dilation : tuple of len 3 filter_dilation : tuple of len 3
The filter dilation used in the forward pass. The filter dilation used in the forward pass.
Also known as input striding. Also known as input striding.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out its convolutions separately.
Returns Returns
------- -------
...@@ -1033,7 +1041,8 @@ def conv3d_grad_wrt_inputs(output_grad, ...@@ -1033,7 +1041,8 @@ def conv3d_grad_wrt_inputs(output_grad,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return grad_input_op(filters, output_grad, input_shape[-3:]) return grad_input_op(filters, output_grad, input_shape[-3:])
...@@ -1177,7 +1186,8 @@ def conv3d_grad_wrt_weights(input, ...@@ -1177,7 +1186,8 @@ def conv3d_grad_wrt_weights(input,
border_mode='valid', border_mode='valid',
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its weights """Compute conv output gradient w.r.t its weights
This function will build the symbolic graph for getting the This function will build the symbolic graph for getting the
...@@ -1241,6 +1251,9 @@ def conv3d_grad_wrt_weights(input, ...@@ -1241,6 +1251,9 @@ def conv3d_grad_wrt_weights(input,
filter_dilation : tuple of len 3 filter_dilation : tuple of len 3
The filter dilation used in the forward pass. The filter dilation used in the forward pass.
Also known as input striding. Also known as input striding.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out its convolutions separately.
Returns Returns
------- -------
...@@ -1291,7 +1304,8 @@ def conv3d_grad_wrt_weights(input, ...@@ -1291,7 +1304,8 @@ def conv3d_grad_wrt_weights(input,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return gradWeight_op(input, output_grad, filter_shape[-3:]) return gradWeight_op(input, output_grad, filter_shape[-3:])
...@@ -1603,8 +1617,6 @@ class BaseAbstractConv(Op): ...@@ -1603,8 +1617,6 @@ class BaseAbstractConv(Op):
self.filter_dilation = tuple(filter_dilation) self.filter_dilation = tuple(filter_dilation)
if num_groups < 1: if num_groups < 1:
raise ValueError("num_groups must have value greater than zero") raise ValueError("num_groups must have value greater than zero")
elif num_groups > 1 and convdim == 3:
raise ValueError("grouped convolution not supported for 3D convolutions")
self.num_groups = num_groups self.num_groups = num_groups
def do_constant_folding(self, node): def do_constant_folding(self, node):
...@@ -1664,7 +1676,6 @@ class BaseAbstractConv(Op): ...@@ -1664,7 +1676,6 @@ class BaseAbstractConv(Op):
tuple(slice(None, None, dilation[i]) for i in range(self.convdim)) tuple(slice(None, None, dilation[i]) for i in range(self.convdim))
] = kern ] = kern
if self.convdim == 2:
if img.shape[1] % self.num_groups != 0: if img.shape[1] % self.num_groups != 0:
raise ValueError( raise ValueError(
'number of input channels must be divible by num_groups') 'number of input channels must be divible by num_groups')
...@@ -1675,11 +1686,13 @@ class BaseAbstractConv(Op): ...@@ -1675,11 +1686,13 @@ class BaseAbstractConv(Op):
raise ValueError( raise ValueError(
'the number of input channels in the kernel should ' 'the number of input channels in the kernel should '
'specify the number of channels of 1 group') 'specify the number of channels of 1 group')
val = _valfrommode(mode)
bval = _bvalfromboundary('fill')
input_channel_offset = img.shape[1] // self.num_groups input_channel_offset = img.shape[1] // self.num_groups
output_channel_offset = kern.shape[0] // self.num_groups output_channel_offset = kern.shape[0] // self.num_groups
if self.convdim == 2:
val = _valfrommode(mode)
bval = _bvalfromboundary('fill')
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter('ignore', np.ComplexWarning) warnings.simplefilter('ignore', np.ComplexWarning)
for b in xrange(img.shape[0]): for b in xrange(img.shape[0]):
...@@ -1692,11 +1705,12 @@ class BaseAbstractConv(Op): ...@@ -1692,11 +1705,12 @@ class BaseAbstractConv(Op):
im0, ...], 1, val, bval, 0) im0, ...], 1, val, bval, 0)
elif self.convdim == 3: elif self.convdim == 3:
for b in xrange(img.shape[0]): for b in xrange(img.shape[0]):
for n in xrange(kern.shape[0]): for g in xrange(self.num_groups):
for im0 in xrange(img.shape[1]): for n in xrange(output_channel_offset):
out[b, n, ...] += convolve(img[b, im0, ...], for im0 in xrange(input_channel_offset):
dilated_kern[n, im0, ...], out[b, g * output_channel_offset + n, ...] += convolve(img[b, g * input_channel_offset + im0, ...],
mode) dilated_kern[g * output_channel_offset + n,
im0, ...], mode)
else: else:
raise NotImplementedError('only 2D and 3D convolution are implemented') raise NotImplementedError('only 2D and 3D convolution are implemented')
return out return out
...@@ -1888,13 +1902,15 @@ class AbstractConv3d(AbstractConv): ...@@ -1888,13 +1902,15 @@ class AbstractConv3d(AbstractConv):
border_mode="valid", border_mode="valid",
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
super(AbstractConv3d, self).__init__(convdim=3, super(AbstractConv3d, self).__init__(convdim=3,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads): def grad(self, inp, grads):
bottom, weights = inp bottom, weights = inp
...@@ -1903,13 +1919,15 @@ class AbstractConv3d(AbstractConv): ...@@ -1903,13 +1919,15 @@ class AbstractConv3d(AbstractConv):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)( self.filter_dilation,
self.num_groups)(
weights, top, bottom.shape[-3:]) weights, top, bottom.shape[-3:])
d_weights = AbstractConv3d_gradWeights(self.imshp, self.kshp, d_weights = AbstractConv3d_gradWeights(self.imshp, self.kshp,
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)( self.filter_dilation,
self.num_groups)(
bottom, top, weights.shape[-3:]) bottom, top, weights.shape[-3:])
...@@ -2033,8 +2051,8 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -2033,8 +2051,8 @@ class AbstractConv_gradWeights(BaseAbstractConv):
mshp0 = mat.shape[0] // self.num_groups mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[1] * self.num_groups mshp1 = mat.shape[1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:]) mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
mat = mat.transpose((1, 0, 2, 3, 4)) mat = mat.transpose((1, 0, 2) + tuple(range(3, 3 + self.convdim)))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-2:]) mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim:])
return mat return mat
if self.num_groups > 1: if self.num_groups > 1:
...@@ -2147,13 +2165,15 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights): ...@@ -2147,13 +2165,15 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
border_mode="valid", border_mode="valid",
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
super(AbstractConv3d_gradWeights, self).__init__(convdim=3, super(AbstractConv3d_gradWeights, self).__init__(convdim=3,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads): def grad(self, inp, grads):
bottom, top = inp[:2] bottom, top = inp[:2]
...@@ -2162,7 +2182,8 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights): ...@@ -2162,7 +2182,8 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(weights, self.filter_dilation,
self.num_groups)(weights,
top, top,
bottom.shape[-3:]) bottom.shape[-3:])
d_top = AbstractConv3d(self.imshp, d_top = AbstractConv3d(self.imshp,
...@@ -2170,7 +2191,8 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights): ...@@ -2170,7 +2191,8 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used # Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer # for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable. # that the dimensions are broadcastable.
...@@ -2414,13 +2436,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs): ...@@ -2414,13 +2436,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs):
border_mode="valid", border_mode="valid",
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
super(AbstractConv3d_gradInputs, self).__init__(convdim=3, super(AbstractConv3d_gradInputs, self).__init__(convdim=3,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
def grad(self, inp, grads): def grad(self, inp, grads):
weights, top = inp[:2] weights, top = inp[:2]
...@@ -2429,13 +2453,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs): ...@@ -2429,13 +2453,15 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(bottom, top, self.filter_dilation,
self.num_groups)(bottom, top,
weights.shape[-3:]) weights.shape[-3:])
d_top = AbstractConv3d(self.imshp, self.kshp, d_top = AbstractConv3d(self.imshp, self.kshp,
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip, self.filter_flip,
self.filter_dilation)(bottom, weights) self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used # Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer # for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable. # that the dimensions are broadcastable.
......
...@@ -127,7 +127,8 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -127,7 +127,8 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
const int dilD = 1, const int dilD = 1,
const int padH = 0, const int padH = 0,
const int padW = 0, const int padW = 0,
const int padD = 0) const int padD = 0,
const int numgroups=1)
{ {
if (PyArray_NDIM(bottom) != 5) if (PyArray_NDIM(bottom) != 5)
{ {
...@@ -178,11 +179,16 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -178,11 +179,16 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
const int kH = PyArray_DIMS(weight)[2]; const int kH = PyArray_DIMS(weight)[2];
const int kW = PyArray_DIMS(weight)[3]; const int kW = PyArray_DIMS(weight)[3];
const int kD = PyArray_DIMS(weight)[4]; const int kD = PyArray_DIMS(weight)[4];
if (nChannels != PyArray_DIMS(weight)[1]) { if (nChannels != PyArray_DIMS(weight)[1] * numgroups) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"Corr3dMM images and kernel must have the same stack size\n"); "Corr3dMM images and kernel must have the same stack size\n");
return NULL; return NULL;
} }
if ((nFilters %% numgroups) != 0) {
PyErr_SetString(PyExc_ValueError,
"CorrMM the number of filters must be divisible by the number of groups\n");
return NULL;
}
// implicit dilated filter // implicit dilated filter
const int dil_kH = (kH - 1) * dilH + 1; const int dil_kH = (kH - 1) * dilH + 1;
const int dil_kW = (kW - 1) * dilW + 1; const int dil_kW = (kW - 1) * dilW + 1;
...@@ -210,7 +216,7 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -210,7 +216,7 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
" weight shape: %%d %%d %%d %%d %%d\n" " weight shape: %%d %%d %%d %%d %%d\n"
" top shape: %%ld %%ld %%ld %%ld %%ld (expected %%d %%d %%d %%d %%d)\n", " top shape: %%ld %%ld %%ld %%ld %%ld (expected %%d %%d %%d %%d %%d)\n",
batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth, batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
nFilters, nChannels, kH, kW, kD, nFilters, nChannels / numgroups, kH, kW, kD,
PyArray_DIMS(top)[0], PyArray_DIMS(top)[1], PyArray_DIMS(top)[0], PyArray_DIMS(top)[1],
PyArray_DIMS(top)[2], PyArray_DIMS(top)[3], PyArray_DIMS(top)[4], PyArray_DIMS(top)[2], PyArray_DIMS(top)[3], PyArray_DIMS(top)[4],
batchSize, nFilters, topHeight, topWidth, topDepth); batchSize, nFilters, topHeight, topWidth, topDepth);
...@@ -241,12 +247,16 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -241,12 +247,16 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
} }
// Define some useful variables // Define some useful variables
const int bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f; const int batch_bottom_stride = PyArray_STRIDES(bottom)[0]/%(n_bytes)f;
const int top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f; const int group_bottom_stride = (PyArray_STRIDES(bottom)[1] * nChannels / numgroups)/%(n_bytes)f;
const int K_ = col_dim[1]; const int batch_top_stride = PyArray_STRIDES(top)[0]/%(n_bytes)f;
const int group_top_stride = (PyArray_STRIDES(top)[1] * nFilters / numgroups)/%(n_bytes)f;
const int K_ = col_dim[1] / numgroups;
const int N_ = col_dim[2]; const int N_ = col_dim[2];
const int col_stride = (K_ * N_); const int col_stride = (K_ * N_ * numgroups);
const int M_ = nFilters; const int group_col_stride = (K_ * N_);
const int group_weight_stride = (PyArray_STRIDES(weight)[0] * nFilters / numgroups)/%(n_bytes)f;
const int M_ = nFilters / numgroups;
const %(c_float_type)s one = 1.0; const %(c_float_type)s one = 1.0;
const %(c_float_type)s zero = 0.0; const %(c_float_type)s zero = 0.0;
char NTrans = 'N'; char NTrans = 'N';
...@@ -280,18 +290,21 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -280,18 +290,21 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im3d2col // First, im3d2col
im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
bottomHeight, bottomWidth, bottomDepth, nChannels, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD, kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
for ( int g = 0; g < numgroups; ++g){
// Second, gemm // Second, gemm
%(gemm)s(&NTrans, &NTrans, %(gemm)s(&NTrans, &NTrans,
&N_, &M_, &K_, &N_, &M_, &K_,
&one, &one,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride, &N_, (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_, (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero, &zero,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_); (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_);
}
} }
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
...@@ -300,7 +313,7 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -300,7 +313,7 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
output = weight; output = weight;
npy_intp weight_dim[2]; npy_intp weight_dim[2];
weight_dim[0] = (npy_intp)max_threads; weight_dim[0] = (npy_intp)max_threads;
weight_dim[1] = (npy_intp)(M_ * K_); weight_dim[1] = (npy_intp)(M_ * K_ * numgroups);
PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2, PyArrayObject* local_weight = (PyArrayObject*)PyArray_ZEROS(2,
weight_dim, PyArray_TYPE(weight), 0); weight_dim, PyArray_TYPE(weight), 0);
...@@ -322,10 +335,12 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -322,10 +335,12 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
// First, im2col // First, im2col
im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, im3d2col((%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride,
bottomHeight, bottomWidth, bottomDepth, nChannels, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD, kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
(%(float_type)s*)PyArray_DATA(col)+ tid * col_stride); (%(float_type)s*)PyArray_DATA(col)+ tid * col_stride);
for ( int g = 0; g < numgroups; ++g){
// Second, gemm // Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0 // Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This // for the first iteration and beta = 1 for subsequent ones. (This
...@@ -333,12 +348,13 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -333,12 +348,13 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
%(gemm)s(&Trans, &NTrans, %(gemm)s(&Trans, &NTrans,
&K_, &M_, &N_, &K_, &M_, &N_,
&one, &one,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_, (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_, (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_,
(n == 0) ? &zero : &one, (n == 0) ? &zero : &one,
(%(float_type)s*)PyArray_DATA(local_weight) + (%(float_type)s*)PyArray_DATA(local_weight) + g * group_weight_stride +
tid * weight_dim[1], &K_); tid * weight_dim[1], &K_);
} }
}
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
...@@ -370,20 +386,23 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom, ...@@ -370,20 +386,23 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
%(blas_set_num_threads)s(1); %(blas_set_num_threads)s(1);
%(omp_flags)s %(omp_flags)s
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
// gemm into columns
int tid = %(omp_get_thread_num)s; int tid = %(omp_get_thread_num)s;
for ( int g = 0; g < numgroups; ++g){
// gemm into columns
%(gemm)s(&NTrans, &Trans, %(gemm)s(&NTrans, &Trans,
&N_, &K_, &M_, &N_, &K_, &M_,
&one, &one,
(%(float_type)s*)PyArray_DATA(top) + n * top_stride, &N_, (%(float_type)s*)PyArray_DATA(top) + n * batch_top_stride + g * group_top_stride, &N_,
(%(float_type)s*)PyArray_DATA(weight), &K_, (%(float_type)s*)PyArray_DATA(weight) + g * group_weight_stride, &K_,
&zero, &zero,
(%(float_type)s*)PyArray_DATA(col) + tid * col_stride, &N_); (%(float_type)s*)PyArray_DATA(col) + tid * col_stride + g * group_col_stride, &N_);
}
// col2im back to the data // col2im back to the data
col2im3d((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels, col2im3d((%(float_type)s*)PyArray_DATA(col) + tid * col_stride, nChannels,
bottomHeight, bottomWidth, bottomDepth, bottomHeight, bottomWidth, bottomDepth,
kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD, kH, kW, kD, dilH, dilW, dilD, padH, padW, padD, dH, dW, dD,
(%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride); (%(float_type)s*)PyArray_DATA(bottom) + n * batch_bottom_stride);
} }
// Restore to previous blas threads // Restore to previous blas threads
%(blas_set_num_threads)s(blas_threads_saved); %(blas_set_num_threads)s(blas_threads_saved);
......
...@@ -40,9 +40,11 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -40,9 +40,11 @@ class BaseCorr3dMM(gof.OpenMPOp):
Perform subsampling of the output (default: (1, 1, 1)). Perform subsampling of the output (default: (1, 1, 1)).
filter_dilation filter_dilation
Perform dilated correlation (default: (1, 1, 1)) Perform dilated correlation (default: (1, 1, 1))
num_groups
Perform grouped convolutions (default: 1)
""" """
check_broadcast = False check_broadcast = False
__props__ = ('border_mode', 'subsample', 'filter_dilation') __props__ = ('border_mode', 'subsample', 'filter_dilation', 'num_groups')
_direction = None _direction = None
...@@ -51,10 +53,11 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -51,10 +53,11 @@ class BaseCorr3dMM(gof.OpenMPOp):
('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2 ('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')), # 2
dH=int64, dW=int64, dD=int64, dH=int64, dW=int64, dD=int64,
dilH=int64, dilW=int64, dilD=int64, dilH=int64, dilW=int64, dilD=int64,
padH=int64, padW=int64, padD=int64) padH=int64, padW=int64, padD=int64,
num_groups=int64)
def __init__(self, border_mode="valid", subsample=(1, 1, 1), def __init__(self, border_mode="valid", subsample=(1, 1, 1),
filter_dilation=(1, 1, 1), openmp=None): filter_dilation=(1, 1, 1), openmp=None, num_groups=1):
super(BaseCorr3dMM, self).__init__(openmp=openmp) super(BaseCorr3dMM, self).__init__(openmp=openmp)
if isinstance(border_mode, integer_types): if isinstance(border_mode, integer_types):
if border_mode < 0: if border_mode < 0:
...@@ -82,6 +85,9 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -82,6 +85,9 @@ class BaseCorr3dMM(gof.OpenMPOp):
raise ValueError("filter_dilation must have three elements") raise ValueError("filter_dilation must have three elements")
self.subsample = tuple(subsample) self.subsample = tuple(subsample)
self.filter_dilation = tuple(filter_dilation) self.filter_dilation = tuple(filter_dilation)
if num_groups < 1:
raise ValueError("Number of groups should be greater than 0")
self.num_groups = num_groups
if not theano.config.blas.ldflags: if not theano.config.blas.ldflags:
# Theano will use a NumPy C implementation of [sd]gemm_ instead. # Theano will use a NumPy C implementation of [sd]gemm_ instead.
...@@ -127,11 +133,12 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -127,11 +133,12 @@ class BaseCorr3dMM(gof.OpenMPOp):
padD = property(lambda self: self.pad[2]) padD = property(lambda self: self.pad[2])
def __str__(self): def __str__(self):
return '%s{%s, %s, %s}' % ( return '%s{%s, %s, %s, %s}' % (
self.__class__.__name__, self.__class__.__name__,
self.border_mode, self.border_mode,
str(self.subsample), str(self.subsample),
str(self.filter_dilation)) str(self.filter_dilation),
str(self.num_groups))
@staticmethod @staticmethod
def as_common_dtype(in1, in2): def as_common_dtype(in1, in2):
...@@ -141,6 +148,11 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -141,6 +148,11 @@ class BaseCorr3dMM(gof.OpenMPOp):
dtype = theano.scalar.upcast(in1.dtype, in2.dtype) dtype = theano.scalar.upcast(in1.dtype, in2.dtype)
return in1.astype(dtype), in2.astype(dtype) return in1.astype(dtype), in2.astype(dtype)
def __setstate__(self, d):
    # Restore every attribute stored in the state dict ``d`` as-is.
    self.__dict__.update(d)
    # Nothing more to do when the restored object already has ``num_groups``.
    if hasattr(self, 'num_groups'):
        return
    # Otherwise fall back to 1, the constructor default.
    # NOTE(review): presumably this covers ops pickled before ``num_groups``
    # was introduced -- confirm.
    self.num_groups = 1
def c_support_code(self): def c_support_code(self):
ccodes = blas_headers.blas_header_text() ccodes = blas_headers.blas_header_text()
if self.blas_type == 'openblas': if self.blas_type == 'openblas':
...@@ -170,7 +182,7 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -170,7 +182,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying any of the support_code_files
return (7, self.openmp, blas_header_version()) return (8, self.openmp, blas_header_version())
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of # REMEMBER TO RAISE c_code_cache_version when changing any of
...@@ -293,6 +305,7 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -293,6 +305,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
int padH = %(params)s->padH; int padH = %(params)s->padH;
int padW = %(params)s->padW; int padW = %(params)s->padW;
int padD = %(params)s->padD; int padD = %(params)s->padD;
int numgroups = %(params)s->num_groups;
PyArrayObject * bottom = %(bottom)s; PyArrayObject * bottom = %(bottom)s;
PyArrayObject * weights = %(weights)s; PyArrayObject * weights = %(weights)s;
...@@ -428,7 +441,7 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -428,7 +441,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
// output is weights: (num_filters, num_channels, height, width, depth) // output is weights: (num_filters, num_channels, height, width, depth)
// height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1 // height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
out_dim[0] = (npy_intp)PyArray_DIMS(top)[1]; out_dim[0] = (npy_intp)PyArray_DIMS(top)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1]; out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1] / numgroups;
out_dim[2] = (npy_intp)kH; // already inferred further above out_dim[2] = (npy_intp)kH; // already inferred further above
out_dim[3] = (npy_intp)kW; // how convenient out_dim[3] = (npy_intp)kW; // how convenient
out_dim[4] = (npy_intp)kD; out_dim[4] = (npy_intp)kD;
...@@ -454,7 +467,7 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -454,7 +467,7 @@ class BaseCorr3dMM(gof.OpenMPOp):
// output is bottom: (batchsize, num_channels, height, width, depth) // output is bottom: (batchsize, num_channels, height, width, depth)
// height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad // height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
out_dim[0] = (npy_intp)PyArray_DIMS(top)[0]; out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1]; out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1] * numgroups;
out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH); out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW); out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
out_dim[4] = (npy_intp)((%(depth)s != -1) ? %(depth)s : (PyArray_DIMS(top)[4] - 1) * dD + (PyArray_DIMS(weights)[4]-1)*dilD + 1 - 2*padD); out_dim[4] = (npy_intp)((%(depth)s != -1) ? %(depth)s : (PyArray_DIMS(top)[4] - 1) * dD + (PyArray_DIMS(weights)[4]-1)*dilD + 1 - 2*padD);
...@@ -516,7 +529,8 @@ class BaseCorr3dMM(gof.OpenMPOp): ...@@ -516,7 +529,8 @@ class BaseCorr3dMM(gof.OpenMPOp):
// Call corr3dMM code // Call corr3dMM code
out2 = corr3dMM(%(bottom)s, %(weights)s, %(top)s, direction, out2 = corr3dMM(%(bottom)s, %(weights)s, %(top)s, direction,
dH, dW, dD, dilH, dilW, dilD, padH, padW, padD); dH, dW, dD, dilH, dilW, dilD, padH, padW, padD,
numgroups);
if (out2==NULL){ if (out2==NULL){
%(fail)s %(fail)s
} }
...@@ -552,7 +566,8 @@ class Corr3dMM(BaseCorr3dMM): ...@@ -552,7 +566,8 @@ class Corr3dMM(BaseCorr3dMM):
The filter dilation operation applied to each input image. The filter dilation operation applied to each input image.
Should be a tuple with 3 elements. Should be a tuple with 3 elements.
Set to `(1, 1, 1)` to disable filter dilation. Set to `(1, 1, 1)` to disable filter dilation.
num_groups
Perform grouped convolutions (default: 1)
""" """
_direction = "forward" _direction = "forward"
...@@ -592,11 +607,13 @@ class Corr3dMM(BaseCorr3dMM): ...@@ -592,11 +607,13 @@ class Corr3dMM(BaseCorr3dMM):
top, = grads top, = grads
d_bottom = Corr3dMM_gradInputs(self.border_mode, d_bottom = Corr3dMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, top, self.filter_dilation,
num_groups=self.num_groups)(weights, top,
bottom.shape[-3:]) bottom.shape[-3:])
d_weights = Corr3dMM_gradWeights(self.border_mode, d_weights = Corr3dMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, top, self.filter_dilation,
num_groups=self.num_groups)(bottom, top,
weights.shape[-3:]) weights.shape[-3:])
return d_bottom, d_weights return d_bottom, d_weights
...@@ -653,6 +670,7 @@ class Corr3dMM_gradWeights(BaseCorr3dMM): ...@@ -653,6 +670,7 @@ class Corr3dMM_gradWeights(BaseCorr3dMM):
imshp = input_shape[0] imshp = input_shape[0]
topshp = input_shape[1] topshp = input_shape[1]
ssize, imshp = imshp[1], list(imshp[2:]) ssize, imshp = imshp[1], list(imshp[2:])
ssize = ssize // self.num_groups
nkern, topshp = topshp[1], list(topshp[2:]) nkern, topshp = topshp[1], list(topshp[2:])
height_width_depth = node.inputs[-3:] height_width_depth = node.inputs[-3:]
if ((dH != 1) or (padH == -1)): if ((dH != 1) or (padH == -1)):
...@@ -691,11 +709,13 @@ class Corr3dMM_gradWeights(BaseCorr3dMM): ...@@ -691,11 +709,13 @@ class Corr3dMM_gradWeights(BaseCorr3dMM):
weights, = grads weights, = grads
d_bottom = Corr3dMM_gradInputs(self.border_mode, d_bottom = Corr3dMM_gradInputs(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(weights, top, self.filter_dilation,
num_groups=self.num_groups)(weights, top,
bottom.shape[-3:]) bottom.shape[-3:])
d_top = Corr3dMM(self.border_mode, d_top = Corr3dMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
num_groups=self.num_groups)(bottom, weights)
d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3 d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3
if len(inp) == 5 else ()) if len(inp) == 5 else ())
return (d_bottom, d_top) + d_height_width_depth return (d_bottom, d_top) + d_height_width_depth
...@@ -738,6 +758,10 @@ class Corr3dMM_gradInputs(BaseCorr3dMM): ...@@ -738,6 +758,10 @@ class Corr3dMM_gradInputs(BaseCorr3dMM):
as_tensor_variable(shape[1]).astype('int64'), as_tensor_variable(shape[1]).astype('int64'),
as_tensor_variable(shape[2]).astype('int64')] as_tensor_variable(shape[2]).astype('int64')]
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0], False,
False, False, False]
else:
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1], broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False, False] False, False, False]
dtype = kern.type.dtype dtype = kern.type.dtype
...@@ -758,6 +782,7 @@ class Corr3dMM_gradInputs(BaseCorr3dMM): ...@@ -758,6 +782,7 @@ class Corr3dMM_gradInputs(BaseCorr3dMM):
kshp = input_shape[0] kshp = input_shape[0]
topshp = input_shape[1] topshp = input_shape[1]
ssize, kshp = kshp[1], list(kshp[2:]) ssize, kshp = kshp[1], list(kshp[2:])
ssize = ssize * self.num_groups
bsize, topshp = topshp[0], list(topshp[2:]) bsize, topshp = topshp[0], list(topshp[2:])
height_width_depth = node.inputs[-3:] height_width_depth = node.inputs[-3:]
if padH == -1: if padH == -1:
...@@ -807,12 +832,14 @@ class Corr3dMM_gradInputs(BaseCorr3dMM): ...@@ -807,12 +832,14 @@ class Corr3dMM_gradInputs(BaseCorr3dMM):
bottom, = grads bottom, = grads
d_weights = Corr3dMM_gradWeights(self.border_mode, d_weights = Corr3dMM_gradWeights(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, self.filter_dilation,
num_groups=self.num_groups)(bottom,
top, top,
weights.shape[-3:]) weights.shape[-3:])
d_top = Corr3dMM(self.border_mode, d_top = Corr3dMM(self.border_mode,
self.subsample, self.subsample,
self.filter_dilation)(bottom, weights) self.filter_dilation,
num_groups=self.num_groups)(bottom, weights)
d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3 d_height_width_depth = ((theano.gradient.DisconnectedType()(),) * 3
if len(inp) == 5 else ()) if len(inp) == 5 else ())
return (d_weights, d_top) + d_height_width_depth return (d_weights, d_top) + d_height_width_depth
......
...@@ -114,7 +114,8 @@ def local_abstractconv3d_gemm(node): ...@@ -114,7 +114,8 @@ def local_abstractconv3d_gemm(node):
kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1, ::-1]
rval = Corr3dMM(border_mode=node.op.border_mode, rval = Corr3dMM(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, kern) filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, kern)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
return [rval] return [rval]
...@@ -163,7 +164,8 @@ def local_abstractconv3d_gradweight_gemm(node): ...@@ -163,7 +164,8 @@ def local_abstractconv3d_gradweight_gemm(node):
rval = Corr3dMM_gradWeights(border_mode=node.op.border_mode, rval = Corr3dMM_gradWeights(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(img, topgrad, shape) filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(img, topgrad, shape)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
# need to flip the kernel if necessary # need to flip the kernel if necessary
...@@ -219,7 +221,8 @@ def local_abstractconv3d_gradinputs_gemm(node): ...@@ -219,7 +221,8 @@ def local_abstractconv3d_gradinputs_gemm(node):
kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1, ::-1]
rval = Corr3dMM_gradInputs(border_mode=node.op.border_mode, rval = Corr3dMM_gradInputs(border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
filter_dilation=node.op.filter_dilation)(kern, topgrad, filter_dilation=node.op.filter_dilation,
num_groups=node.op.num_groups)(kern, topgrad,
shape) shape)
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(node.outputs[0], rval)
...@@ -267,6 +270,8 @@ def local_conv3d_cpu(node): ...@@ -267,6 +270,8 @@ def local_conv3d_cpu(node):
return None return None
if node.op.filter_dilation != (1, 1, 1): if node.op.filter_dilation != (1, 1, 1):
return None return None
if node.op.num_groups > 1:
return None
bias = theano.tensor.zeros_like(kern[:, 0, 0, 0, 0]) bias = theano.tensor.zeros_like(kern[:, 0, 0, 0, 0])
...@@ -419,6 +424,8 @@ def local_conv3d_gradweight_cpu(node): ...@@ -419,6 +424,8 @@ def local_conv3d_gradweight_cpu(node):
return None return None
if node.op.filter_dilation != (1, 1, 1): if node.op.filter_dilation != (1, 1, 1):
return None return None
if node.op.num_groups > 1:
return None
# conv3D expects shape (batch, row, column, time, channel) # conv3D expects shape (batch, row, column, time, channel)
img = img.dimshuffle(0, 2, 3, 4, 1) img = img.dimshuffle(0, 2, 3, 4, 1)
...@@ -544,6 +551,8 @@ def local_conv3d_gradinputs_cpu(node): ...@@ -544,6 +551,8 @@ def local_conv3d_gradinputs_cpu(node):
return None return None
if node.op.filter_dilation != (1, 1, 1): if node.op.filter_dilation != (1, 1, 1):
return None return None
if node.op.num_groups > 1:
return None
# need to flip the kernel if necessary (conv3D does not flip) # need to flip the kernel if necessary (conv3D does not flip)
if node.op.filter_flip: if node.op.filter_flip:
......
...@@ -1710,12 +1710,12 @@ class TestConv2dGrads(unittest.TestCase): ...@@ -1710,12 +1710,12 @@ class TestConv2dGrads(unittest.TestCase):
class Grouped_conv_noOptim(unittest.TestCase): class Grouped_conv_noOptim(unittest.TestCase):
conv2d = theano.tensor.nnet.abstract_conv.AbstractConv2d conv = theano.tensor.nnet.abstract_conv.AbstractConv2d
conv2d_gradw = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights conv_gradw = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights
conv2d_gradi = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs conv_gradi = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs
conv2d_op = theano.tensor.nnet.abstract_conv.AbstractConv2d conv_op = theano.tensor.nnet.abstract_conv.AbstractConv2d
conv2d_gradw_op = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights conv_gradw_op = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights
conv2d_gradi_op = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs conv_gradi_op = theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs
mode = theano.Mode(optimizer=None) mode = theano.Mode(optimizer=None)
flip_filter = False flip_filter = False
is_dnn = False is_dnn = False
...@@ -1729,32 +1729,45 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1729,32 +1729,45 @@ class Grouped_conv_noOptim(unittest.TestCase):
self.top_shape = [(5, 6, 3, 3), (4, 6, 3, 3), (3, 4, 3, 1), (2, 4, 5, 3)] self.top_shape = [(5, 6, 3, 3), (4, 6, 3, 3), (3, 4, 3, 1), (2, 4, 5, 3)]
self.filter_dilation = (1, 1) self.filter_dilation = (1, 1)
self.ref_mode = 'FAST_RUN' self.ref_mode = 'FAST_RUN'
self.convdim = 2
self.corr_fwd = conv2d_corr
self.corr_gradw = conv2d_corr_gw
self.corr_gradi = conv2d_corr_gi
if theano.config.cxx == "": if theano.config.cxx == "":
raise SkipTest("CorrMM needs cxx") raise SkipTest("CorrMM needs cxx")
def test_fwd(self): def test_fwd(self):
if self.convdim == 2:
img_sym = theano.tensor.tensor4('img') img_sym = theano.tensor.tensor4('img')
kern_sym = theano.tensor.tensor4('kern') kern_sym = theano.tensor.tensor4('kern')
else:
img_sym = theano.tensor.tensor5('img')
kern_sym = theano.tensor.tensor5('kern')
for imshp, kshp, groups in zip(self.img_shape, self.kern_shape, self.num_groups): for imshp, kshp, groups in zip(self.img_shape, self.kern_shape, self.num_groups):
img = np.random.random(imshp).astype(theano.config.floatX) img = np.random.random(imshp).astype(theano.config.floatX)
kern = np.random.random(kshp).astype(theano.config.floatX) kern = np.random.random(kshp).astype(theano.config.floatX)
split_imgs = np.split(img, groups, axis=1) split_imgs = np.split(img, groups, axis=1)
split_kern = np.split(kern, groups, axis=0) split_kern = np.split(kern, groups, axis=0)
grouped_conv_op = self.conv2d(border_mode=self.border_mode, grouped_conv_op = self.conv(border_mode=self.border_mode,
subsample=self.subsample, subsample=self.subsample,
filter_dilation=self.filter_dilation, filter_dilation=self.filter_dilation,
num_groups=groups) num_groups=groups)
if self.flip_filter: if self.flip_filter:
if self.convdim == 2:
grouped_conv_output = grouped_conv_op(img_sym, kern_sym[:, :, ::-1, ::-1]) grouped_conv_output = grouped_conv_op(img_sym, kern_sym[:, :, ::-1, ::-1])
else:
grouped_conv_output = grouped_conv_op(img_sym, kern_sym[:, :, ::-1, ::-1, ::-1])
else: else:
grouped_conv_output = grouped_conv_op(img_sym, kern_sym) grouped_conv_output = grouped_conv_op(img_sym, kern_sym)
grouped_func = theano.function([img_sym, kern_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([img_sym, kern_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_op) assert any([isinstance(node.op, self.conv_op)
for node in grouped_func.maker.fgraph.toposort()]) for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, kern) grouped_output = grouped_func(img, kern)
ref_conv_op = conv2d_corr(img_sym, ref_conv_op = self.corr_fwd(img_sym,
kern_sym, kern_sym,
border_mode=self.border_mode, border_mode=self.border_mode,
subsample=self.subsample, subsample=self.subsample,
...@@ -1773,29 +1786,38 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1773,29 +1786,38 @@ class Grouped_conv_noOptim(unittest.TestCase):
eps=1) eps=1)
def test_gradweights(self): def test_gradweights(self):
if self.convdim == 2:
img_sym = theano.tensor.tensor4('img') img_sym = theano.tensor.tensor4('img')
top_sym = theano.tensor.tensor4('top') top_sym = theano.tensor.tensor4('kern')
else:
img_sym = theano.tensor.tensor5('img')
top_sym = theano.tensor.tensor5('kern')
for imshp, kshp, tshp, groups in zip(self.img_shape, self.kern_shape, self.top_shape, self.num_groups): for imshp, kshp, tshp, groups in zip(self.img_shape, self.kern_shape, self.top_shape, self.num_groups):
img = np.random.random(imshp).astype(theano.config.floatX) img = np.random.random(imshp).astype(theano.config.floatX)
top = np.random.random(tshp).astype(theano.config.floatX) top = np.random.random(tshp).astype(theano.config.floatX)
split_imgs = np.split(img, groups, axis=1) split_imgs = np.split(img, groups, axis=1)
split_top = np.split(top, groups, axis=1) split_top = np.split(top, groups, axis=1)
grouped_convgrad_op = self.conv2d_gradw(border_mode=self.border_mode, grouped_convgrad_op = self.conv_gradw(border_mode=self.border_mode,
subsample=self.subsample, subsample=self.subsample,
filter_dilation=self.filter_dilation, filter_dilation=self.filter_dilation,
num_groups=groups) num_groups=groups)
grouped_conv_output = grouped_convgrad_op(img_sym, grouped_conv_output = grouped_convgrad_op(img_sym,
top_sym, top_sym,
tensor.as_tensor_variable(kshp if self.is_dnn else kshp[-2:])) tensor.as_tensor_variable(
kshp if self.is_dnn
else kshp[-self.convdim:]))
if self.flip_filter: if self.flip_filter:
if self.convdim == 2:
grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1] grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1]
else:
grouped_conv_output = grouped_conv_output[:, :, ::-1, ::-1, ::-1]
grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([img_sym, top_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_gradw_op) assert any([isinstance(node.op, self.conv_gradw_op)
for node in grouped_func.maker.fgraph.toposort()]) for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(img, top) grouped_output = grouped_func(img, top)
ref_conv_op = conv2d_corr_gw(img_sym, ref_conv_op = self.corr_gradw(img_sym,
top_sym, top_sym,
kshp, kshp,
border_mode=self.border_mode, border_mode=self.border_mode,
...@@ -1811,37 +1833,50 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1811,37 +1833,50 @@ class Grouped_conv_noOptim(unittest.TestCase):
def conv_gradweight(inputs_val, output_val): def conv_gradweight(inputs_val, output_val):
return grouped_convgrad_op(inputs_val, output_val, return grouped_convgrad_op(inputs_val, output_val,
tensor.as_tensor_variable(kshp if self.is_dnn else kshp[-2:])) tensor.as_tensor_variable(
kshp if self.is_dnn
else kshp[-self.convdim:]))
utt.verify_grad(conv_gradweight, utt.verify_grad(conv_gradweight,
[img, top], [img, top],
mode=self.mode, eps=1) mode=self.mode, eps=1)
def test_gradinputs(self): def test_gradinputs(self):
if self.convdim == 2:
kern_sym = theano.tensor.tensor4('kern') kern_sym = theano.tensor.tensor4('kern')
top_sym = theano.tensor.tensor4('top') top_sym = theano.tensor.tensor4('top')
else:
kern_sym = theano.tensor.tensor5('kern')
top_sym = theano.tensor.tensor5('top')
for imshp, kshp, tshp, groups in zip(self.img_shape, self.kern_shape, self.top_shape, self.num_groups): for imshp, kshp, tshp, groups in zip(self.img_shape, self.kern_shape, self.top_shape, self.num_groups):
kern = np.random.random(kshp).astype(theano.config.floatX) kern = np.random.random(kshp).astype(theano.config.floatX)
top = np.random.random(tshp).astype(theano.config.floatX) top = np.random.random(tshp).astype(theano.config.floatX)
split_kerns = np.split(kern, groups, axis=0) split_kerns = np.split(kern, groups, axis=0)
split_top = np.split(top, groups, axis=1) split_top = np.split(top, groups, axis=1)
grouped_convgrad_op = self.conv2d_gradi(border_mode=self.border_mode, grouped_convgrad_op = self.conv_gradi(border_mode=self.border_mode,
subsample=self.subsample, subsample=self.subsample,
filter_dilation=self.filter_dilation, filter_dilation=self.filter_dilation,
num_groups=groups) num_groups=groups)
if self.flip_filter: if self.flip_filter:
grouped_conv_output = grouped_convgrad_op(kern_sym[:, :, ::-1, ::-1], top_sym, tensor.as_tensor_variable(imshp[-2:])) if self.convdim == 2:
grouped_conv_output = grouped_convgrad_op(kern_sym[:, :, ::-1, ::-1], top_sym,
tensor.as_tensor_variable(imshp[-self.convdim:]))
else:
grouped_conv_output = grouped_convgrad_op(kern_sym[:, :, ::-1, ::-1, ::-1], top_sym,
tensor.as_tensor_variable(imshp[-self.convdim:]))
else: else:
grouped_conv_output = grouped_convgrad_op(kern_sym, grouped_conv_output = grouped_convgrad_op(kern_sym,
top_sym, top_sym,
tensor.as_tensor_variable(imshp if self.is_dnn else imshp[-2:])) tensor.as_tensor_variable(
imshp if self.is_dnn
else imshp[-self.convdim:]))
grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode) grouped_func = theano.function([kern_sym, top_sym], grouped_conv_output, mode=self.mode)
assert any([isinstance(node.op, self.conv2d_gradi_op) assert any([isinstance(node.op, self.conv_gradi_op)
for node in grouped_func.maker.fgraph.toposort()]) for node in grouped_func.maker.fgraph.toposort()])
grouped_output = grouped_func(kern, top) grouped_output = grouped_func(kern, top)
ref_conv_op = conv2d_corr_gi(kern_sym, ref_conv_op = self.corr_gradi(kern_sym,
top_sym, top_sym,
imshp, imshp,
border_mode=self.border_mode, border_mode=self.border_mode,
...@@ -1857,13 +1892,43 @@ class Grouped_conv_noOptim(unittest.TestCase): ...@@ -1857,13 +1892,43 @@ class Grouped_conv_noOptim(unittest.TestCase):
def conv_gradinputs(filters_val, output_val): def conv_gradinputs(filters_val, output_val):
return grouped_convgrad_op(filters_val, output_val, return grouped_convgrad_op(filters_val, output_val,
tensor.as_tensor_variable(imshp if self.is_dnn else imshp[-2:])) tensor.as_tensor_variable(
imshp if self.is_dnn
else imshp[-self.convdim:]))
utt.verify_grad(conv_gradinputs, utt.verify_grad(conv_gradinputs,
[kern, top], [kern, top],
mode=self.mode, eps=1) mode=self.mode, eps=1)
class Grouped_conv3d_noOptim(Grouped_conv_noOptim):
    """3D flavour of the grouped-convolution tests, compiled without optimization.

    This class only selects the ops under test and builds the fixtures in
    ``setUp``; the ``test_*`` methods are inherited from
    ``Grouped_conv_noOptim`` and branch on ``self.convdim``.
    """

    # Ops used to build the grouped graphs.
    conv = theano.tensor.nnet.abstract_conv.AbstractConv3d
    conv_gradw = theano.tensor.nnet.abstract_conv.AbstractConv3d_gradWeights
    conv_gradi = theano.tensor.nnet.abstract_conv.AbstractConv3d_gradInputs

    # Op classes the inherited tests look for in the compiled graph; with no
    # optimizer they are the very same abstract ops.
    conv_op = conv
    conv_gradw_op = conv_gradw
    conv_gradi_op = conv_gradi

    mode = theano.Mode(optimizer=None)
    flip_filter = False
    is_dnn = False

    def setUp(self):
        # One row per scenario:
        #   (num_groups, image shape, kernel shape, output-gradient shape)
        # In every row the kernel's second axis equals the image's second
        # axis divided by num_groups.
        scenarios = [
            (3, (2, 6, 5, 5, 5), (3, 2, 3, 3, 3), (2, 3, 3, 3, 3)),
            (2, (1, 4, 7, 5, 7), (6, 2, 5, 3, 5), (1, 6, 3, 3, 3)),
            (4, (1, 8, 5, 3, 5), (4, 2, 3, 3, 3), (1, 4, 3, 1, 3)),
            (4, (2, 4, 7, 7, 7), (4, 1, 3, 5, 3), (2, 4, 5, 3, 5)),
        ]
        # Transpose the rows into the four parallel lists the tests zip over.
        (self.num_groups,
         self.img_shape,
         self.kern_shape,
         self.top_shape) = [list(column) for column in zip(*scenarios)]

        # Convolution hyper-parameters shared by all scenarios.
        self.convdim = 3
        self.border_mode = 'valid'
        self.subsample = (1, 1, 1)
        self.filter_dilation = (1, 1, 1)

        # Reference implementations the grouped results are compared with.
        self.ref_mode = 'FAST_RUN'
        self.corr_fwd = conv3d_corr
        self.corr_gradw = conv3d_corr_gw
        self.corr_gradi = conv3d_corr_gi

        # The reference path needs a C++ compiler; skip when none is configured.
        if theano.config.cxx == "":
            raise SkipTest("CorrMM needs cxx")
class Separable_conv(unittest.TestCase): class Separable_conv(unittest.TestCase):
def test_interface(self): def test_interface(self):
......
...@@ -422,12 +422,12 @@ class TestGroupCorr2d(Grouped_conv_noOptim): ...@@ -422,12 +422,12 @@ class TestGroupCorr2d(Grouped_conv_noOptim):
mode = theano.compile.get_mode("FAST_RUN") mode = theano.compile.get_mode("FAST_RUN")
else: else:
mode = None mode = None
conv2d = corr.CorrMM conv = corr.CorrMM
conv2d_gradw = corr.CorrMM_gradWeights conv_gradw = corr.CorrMM_gradWeights
conv2d_gradi = corr.CorrMM_gradInputs conv_gradi = corr.CorrMM_gradInputs
conv2d_op = corr.CorrMM conv_op = corr.CorrMM
conv2d_gradw_op = corr.CorrMM_gradWeights conv_gradw_op = corr.CorrMM_gradWeights
conv2d_gradi_op = corr.CorrMM_gradInputs conv_gradi_op = corr.CorrMM_gradInputs
flip_filter = True flip_filter = True
is_dnn = False is_dnn = False
...@@ -440,13 +440,13 @@ class TestGroupCorr2d(Grouped_conv_noOptim): ...@@ -440,13 +440,13 @@ class TestGroupCorr2d(Grouped_conv_noOptim):
kern_sym = T.tensor4('kern') kern_sym = T.tensor4('kern')
# grouped convolution graph # grouped convolution graph
conv_group = self.conv2d(num_groups=groups)(bottom_sym, kern_sym) conv_group = self.conv(num_groups=groups)(bottom_sym, kern_sym)
gconv_func = theano.function([bottom_sym, kern_sym], conv_group, mode=self.mode) gconv_func = theano.function([bottom_sym, kern_sym], conv_group, mode=self.mode)
# Graph for the normal hard way # Graph for the normal hard way
kern_offset = kern_sym.shape[0] // groups kern_offset = kern_sym.shape[0] // groups
bottom_offset = bottom_sym.shape[1] // groups bottom_offset = bottom_sym.shape[1] // groups
split_conv_output = [self.conv2d()(bottom_sym[:, i * bottom_offset:(i + 1) * bottom_offset, :, :], split_conv_output = [self.conv()(bottom_sym[:, i * bottom_offset:(i + 1) * bottom_offset, :, :],
kern_sym[i * kern_offset:(i + 1) * kern_offset, :, :, :]) kern_sym[i * kern_offset:(i + 1) * kern_offset, :, :, :])
for i in range(groups)] for i in range(groups)]
concatenated_output = T.concatenate(split_conv_output, axis=1) concatenated_output = T.concatenate(split_conv_output, axis=1)
......
...@@ -12,6 +12,7 @@ import theano ...@@ -12,6 +12,7 @@ import theano
import theano.tensor as T import theano.tensor as T
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr3d, conv from theano.tensor.nnet import corr3d, conv
from theano.tensor.nnet.tests.test_abstract_conv import Grouped_conv3d_noOptim
class TestCorr3D(utt.InferShapeTester): class TestCorr3D(utt.InferShapeTester):
...@@ -418,6 +419,21 @@ class TestCorr3D(utt.InferShapeTester): ...@@ -418,6 +419,21 @@ class TestCorr3D(utt.InferShapeTester):
self.validate((3, 1, 7, 5, 5), (2, 1, 2, 3, 3), (2, 1, 1), non_contiguous=True) self.validate((3, 1, 7, 5, 5), (2, 1, 2, 3, 3), (2, 1, 1), non_contiguous=True)
class TestGroupCorr3d(Grouped_conv3d_noOptim):
    """Run the inherited grouped 3D convolution tests on the Corr3dMM ops."""

    # FAST_COMPILE is replaced by FAST_RUN; for any other configured mode the
    # attribute is left as None.
    mode = (theano.compile.get_mode("FAST_RUN")
            if theano.config.mode == "FAST_COMPILE"
            else None)

    # Ops used to build the graphs under test.
    conv = corr3d.Corr3dMM
    conv_gradw = corr3d.Corr3dMM_gradWeights
    conv_gradi = corr3d.Corr3dMM_gradInputs

    # Op classes expected to appear in the compiled graph (same classes).
    conv_op = conv
    conv_gradw_op = conv_gradw
    conv_gradi_op = conv_gradi

    flip_filter = True
    is_dnn = False
if __name__ == '__main__': if __name__ == '__main__':
t = TestCorr3D('setUp') t = TestCorr3D('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论