提交 0d7d09a1 作者: affanv14 提交者: Mohammed Affan

change get_conv_grad* functions to support num_groups

上级 0b32b7e3
...@@ -138,7 +138,8 @@ def get_conv_shape_1axis(image_shape, kernel_shape, border_mode, ...@@ -138,7 +138,8 @@ def get_conv_shape_1axis(image_shape, kernel_shape, border_mode,
def get_conv_gradweights_shape(image_shape, top_shape, def get_conv_gradweights_shape(image_shape, top_shape,
border_mode, subsample, border_mode, subsample,
filter_dilation=None): filter_dilation=None,
num_groups=1):
""" """
This function tries to compute the kernel shape of convolution gradWeights. This function tries to compute the kernel shape of convolution gradWeights.
...@@ -166,6 +167,8 @@ def get_conv_gradweights_shape(image_shape, top_shape, ...@@ -166,6 +167,8 @@ def get_conv_gradweights_shape(image_shape, top_shape,
filter_dilation: tuple of int (symbolic or numeric). Its two or three filter_dilation: tuple of int (symbolic or numeric). Its two or three
elements correspond respectively to the dilation on height and elements correspond respectively to the dilation on height and
width axis. width axis.
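num_groups: An int which specifies the number of separate groups the
channels are divided into. Defaults to 1 (no grouping); it is only
taken into account when subsample has two elements.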
Returns Returns
------- -------
...@@ -180,6 +183,8 @@ def get_conv_gradweights_shape(image_shape, top_shape, ...@@ -180,6 +183,8 @@ def get_conv_gradweights_shape(image_shape, top_shape,
if filter_dilation is None: if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int') filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1 and len(subsample) == 2:
nchan = nchan // num_groups
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
out_shp = tuple(get_conv_gradweights_shape_1axis( out_shp = tuple(get_conv_gradweights_shape_1axis(
...@@ -244,7 +249,8 @@ def get_conv_gradweights_shape_1axis(image_shape, top_shape, border_mode, ...@@ -244,7 +249,8 @@ def get_conv_gradweights_shape_1axis(image_shape, top_shape, border_mode,
def get_conv_gradinputs_shape(kernel_shape, top_shape, def get_conv_gradinputs_shape(kernel_shape, top_shape,
border_mode, subsample, border_mode, subsample,
filter_dilation=None): filter_dilation=None,
num_groups=1):
""" """
This function tries to compute the image shape of convolution gradInputs. This function tries to compute the image shape of convolution gradInputs.
...@@ -272,6 +278,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape, ...@@ -272,6 +278,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape,
filter_dilation: tuple of int (symbolic or numeric). Its two or three filter_dilation: tuple of int (symbolic or numeric). Its two or three
elements correspond respectively to the dilation on height and elements correspond respectively to the dilation on height and
width axis. width axis.
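num_groups: An int which specifies the number of separate groups the
convolution is divided into. Defaults to 1 (no grouping); it is only
taken into account when subsample has two elements.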
Returns Returns
------- -------
...@@ -285,6 +293,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape, ...@@ -285,6 +293,8 @@ def get_conv_gradinputs_shape(kernel_shape, top_shape,
if filter_dilation is None: if filter_dilation is None:
filter_dilation = np.ones(len(subsample), dtype='int') filter_dilation = np.ones(len(subsample), dtype='int')
if num_groups > 1 and len(subsample) == 2:
nkern = nkern * num_groups
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
out_shp = tuple(get_conv_gradinputs_shape_1axis( out_shp = tuple(get_conv_gradinputs_shape_1axis(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论