提交 6cfe8567 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

modify gradweights and gradinput functions to support num_groups

上级 0d7d09a1
......@@ -648,7 +648,8 @@ def conv2d_grad_wrt_inputs(output_grad,
border_mode='valid',
subsample=(1, 1),
filter_flip=True,
filter_dilation=(1, 1)):
filter_dilation=(1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its inputs
This function builds the symbolic graph for getting the
......@@ -721,6 +722,9 @@ def conv2d_grad_wrt_inputs(output_grad,
filter_dilation : tuple of len 2
The filter dilation used in the forward pass.
Also known as input striding.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out convolutions separately.
Returns
-------
......@@ -771,7 +775,8 @@ def conv2d_grad_wrt_inputs(output_grad,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
return grad_input_op(filters, output_grad, input_shape[-2:])
......@@ -918,7 +923,8 @@ def conv2d_grad_wrt_weights(input,
border_mode='valid',
subsample=(1, 1),
filter_flip=True,
filter_dilation=(1, 1)):
filter_dilation=(1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its weights
This function will build the symbolic graph for getting the
......@@ -983,6 +989,9 @@ def conv2d_grad_wrt_weights(input,
filter_dilation : tuple of len 2
The filter dilation used in the forward pass.
Also known as input striding.
num_groups : int
Divides the image, kernel and output tensors into num_groups
separate groups, each of which carries out convolutions separately.
Returns
-------
......@@ -1033,7 +1042,8 @@ def conv2d_grad_wrt_weights(input,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
return gradWeight_op(input, output_grad, filter_shape[-2:])
......@@ -1973,15 +1983,17 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(weights,
top,
bottom.shape[-2:])
self.filter_dilation,
self.num_groups)(weights,
top,
bottom.shape[-2:])
d_top = AbstractConv2d(self.imshp,
self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(bottom, weights)
self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
......@@ -2230,14 +2242,16 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(
self.filter_dilation,
self.num_groups)(
bottom, top,
weights.shape[-2:])
d_top = AbstractConv2d(self.imshp, self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(bottom, weights)
self.filter_dilation,
self.num_groups)(bottom, weights)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论