提交 0d7d09a1 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

change get_conv_grad* functions to support num_groups

上级 0b32b7e3
......@@ -138,7 +138,8 @@ def get_conv_shape_1axis(image_shape, kernel_shape, border_mode,
def get_conv_gradweights_shape(image_shape, top_shape,
border_mode, subsample,
filter_dilation=None):
filter_dilation=None,
num_groups=1):
"""
This function tries to compute the kernel shape of convolution gradWeights.
......@@ -166,6 +167,8 @@ def get_conv_gradweights_shape(image_shape, top_shape,
filter_dilation: tuple of int (symbolic or numeric). Its two or three
elements correspond respectively to the dilation on height and
width axis.
num_groups: An int which specifies the number of separate groups to
be divided into.
Returns
-------
......@@ -180,6 +183,8 @@ def get_conv_gradweights_shape(image_shape, top_shape,
if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1 and len(subsample) == 2:
nchan = nchan // num_groups
if isinstance(border_mode, tuple):
out_shp = tuple(get_conv_gradweights_shape_1axis(
......@@ -244,7 +249,8 @@ def get_conv_gradweights_shape_1axis(image_shape, top_shape, border_mode,
def get_conv_gradinputs_shape(kernel_shape, top_shape,
border_mode, subsample,
filter_dilation=None):
filter_dilation=None,
num_groups=1):
"""
This function tries to compute the image shape of convolution gradInputs.
......@@ -272,6 +278,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape,
filter_dilation: tuple of int (symbolic or numeric). Its two or three
elements correspond respectively to the dilation on height and
width axis.
num_groups: An int which specifies the number of separate groups to
be divided into.
Returns
-------
......@@ -285,6 +293,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape,
if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1 and len(subsample) == 2:
nkern = nkern * num_groups
if isinstance(border_mode, tuple):
out_shp = tuple(get_conv_gradinputs_shape_1axis(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论