Commit 6cfe8567 authored by affanv14, committed by Mohammed Affan

modify gradweights and gradinput functions to support num_groups

Parent 0d7d09a1
@@ -648,7 +648,8 @@ def conv2d_grad_wrt_inputs(output_grad,
                            border_mode='valid',
                            subsample=(1, 1),
                            filter_flip=True,
-                           filter_dilation=(1, 1)):
+                           filter_dilation=(1, 1),
+                           num_groups=1):
     """Compute conv output gradient w.r.t its inputs
 
     This function builds the symbolic graph for getting the
@@ -721,6 +722,9 @@ def conv2d_grad_wrt_inputs(output_grad,
     filter_dilation : tuple of len 2
         The filter dilation used in the forward pass.
         Also known as input striding.
+    num_groups : int
+        Divides the image, kernel and output tensors into num_groups
+        separate groups. Each which carry out convolutions separately
 
     Returns
     -------
@@ -771,7 +775,8 @@ def conv2d_grad_wrt_inputs(output_grad,
                                               border_mode=border_mode,
                                               subsample=subsample,
                                               filter_flip=filter_flip,
-                                              filter_dilation=filter_dilation)
+                                              filter_dilation=filter_dilation,
+                                              num_groups=num_groups)
 
     return grad_input_op(filters, output_grad, input_shape[-2:])
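
With this change, conv2d_grad_wrt_inputs takes the same num_groups argument as the forward convolution. The following is a minimal usage sketch, not part of the commit; the tensor names and shapes are illustrative assumptions for a grouped convolution with num_groups=2 (4 input channels split into 2 groups, 6 output channels, 3x3 kernels, 'valid' mode).

import theano
import theano.tensor as T
from theano.tensor.nnet.abstract_conv import conv2d_grad_wrt_inputs

# Assumed forward pass: input (8, 4, 7, 7) convolved with filters of shape
# (6, 2, 3, 3) using num_groups=2, giving an output of shape (8, 6, 5, 5).
filters = T.tensor4('filters')          # (out channels, in channels // num_groups, kh, kw)
output_grad = T.tensor4('output_grad')  # gradient w.r.t. the conv output

grad_input = conv2d_grad_wrt_inputs(output_grad, filters,
                                    input_shape=(8, 4, 7, 7),
                                    filter_shape=(6, 2, 3, 3),
                                    border_mode='valid',
                                    subsample=(1, 1),
                                    num_groups=2)
f = theano.function([output_grad, filters], grad_input)
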
@@ -918,7 +923,8 @@ def conv2d_grad_wrt_weights(input,
                             border_mode='valid',
                             subsample=(1, 1),
                             filter_flip=True,
-                            filter_dilation=(1, 1)):
+                            filter_dilation=(1, 1),
+                            num_groups=1):
     """Compute conv output gradient w.r.t its weights
 
     This function will build the symbolic graph for getting the
@@ -983,6 +989,9 @@ def conv2d_grad_wrt_weights(input,
     filter_dilation : tuple of len 2
         The filter dilation used in the forward pass.
         Also known as input striding.
+    num_groups : int
+        Divides the image, kernel and output tensors into num_groups
+        separate groups. Each which carry out convolutions separately
 
     Returns
     -------
@@ -1033,7 +1042,8 @@ def conv2d_grad_wrt_weights(input,
                                                border_mode=border_mode,
                                                subsample=subsample,
                                                filter_flip=filter_flip,
-                                               filter_dilation=filter_dilation)
+                                               filter_dilation=filter_dilation,
+                                               num_groups=num_groups)
 
     return gradWeight_op(input, output_grad, filter_shape[-2:])
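
conv2d_grad_wrt_weights gains the same argument. A companion sketch under the same assumed shapes (again illustrative, not from the commit):

import theano
import theano.tensor as T
from theano.tensor.nnet.abstract_conv import conv2d_grad_wrt_weights

inp = T.tensor4('inp')                  # (batch, in channels, ih, iw)
output_grad = T.tensor4('output_grad')  # (batch, out channels, oh, ow)

grad_weights = conv2d_grad_wrt_weights(inp, output_grad,
                                       filter_shape=(6, 2, 3, 3),
                                       input_shape=(8, 4, 7, 7),
                                       border_mode='valid',
                                       subsample=(1, 1),
                                       num_groups=2)
g = theano.function([inp, output_grad], grad_weights)
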
@@ -1973,7 +1983,8 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
                                              self.border_mode,
                                              self.subsample,
                                              self.filter_flip,
-                                             self.filter_dilation)(weights,
+                                             self.filter_dilation,
+                                             self.num_groups)(weights,
                                                                    top,
                                                                    bottom.shape[-2:])
         d_top = AbstractConv2d(self.imshp,
@@ -1981,7 +1992,8 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
                                self.border_mode,
                                self.subsample,
                                self.filter_flip,
-                               self.filter_dilation)(bottom, weights)
+                               self.filter_dilation,
+                               self.num_groups)(bottom, weights)
         # Make sure that the broadcastable pattern of the inputs is used
         # for the gradients, even if the grad opts are not able to infer
         # that the dimensions are broadcastable.
@@ -2230,14 +2242,16 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
                                                 self.border_mode,
                                                 self.subsample,
                                                 self.filter_flip,
-                                                self.filter_dilation)(
+                                                self.filter_dilation,
+                                                self.num_groups)(
             bottom, top,
             weights.shape[-2:])
         d_top = AbstractConv2d(self.imshp, self.kshp,
                                self.border_mode,
                                self.subsample,
                                self.filter_flip,
-                               self.filter_dilation)(bottom, weights)
+                               self.filter_dilation,
+                               self.num_groups)(bottom, weights)
         # Make sure that the broadcastable pattern of the inputs is used
         # for the gradients, even if the grad opts are not able to infer
         # that the dimensions are broadcastable.
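
The Op-level changes above make the grad() methods of AbstractConv2d_gradWeights and AbstractConv2d_gradInputs forward num_groups, so gradients taken through these grad ops stay grouped. A rough end-to-end sketch, assuming a Theano build in which the forward conv2d already accepts num_groups (added in the same series of changes); shapes are again illustrative:

import theano
import theano.tensor as T
from theano.tensor.nnet import conv2d

x = T.tensor4('x')   # assumed shape (8, 4, 7, 7)
w = T.tensor4('w')   # assumed shape (6, 2, 3, 3), i.e. 2 input channels per group

out = conv2d(x, w, border_mode='valid', num_groups=2)
cost = out.sum()
# theano.grad builds AbstractConv2d_gradInputs / AbstractConv2d_gradWeights nodes;
# with this commit their own grad() methods propagate num_groups as well.
gx, gw = theano.grad(cost, [x, w])
f = theano.function([x, w], [gx, gw])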