提交 b92db391 authored 作者: affanv14's avatar affanv14

add numgroups to conv3d grad functions

上级 c776e6fa
...@@ -912,7 +912,8 @@ def conv3d_grad_wrt_inputs(output_grad, ...@@ -912,7 +912,8 @@ def conv3d_grad_wrt_inputs(output_grad,
border_mode='valid', border_mode='valid',
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its inputs """Compute conv output gradient w.r.t its inputs
This function builds the symbolic graph for getting the This function builds the symbolic graph for getting the
...@@ -1035,7 +1036,8 @@ def conv3d_grad_wrt_inputs(output_grad, ...@@ -1035,7 +1036,8 @@ def conv3d_grad_wrt_inputs(output_grad,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return grad_input_op(filters, output_grad, input_shape[-3:]) return grad_input_op(filters, output_grad, input_shape[-3:])
...@@ -1179,7 +1181,8 @@ def conv3d_grad_wrt_weights(input, ...@@ -1179,7 +1181,8 @@ def conv3d_grad_wrt_weights(input,
border_mode='valid', border_mode='valid',
subsample=(1, 1, 1), subsample=(1, 1, 1),
filter_flip=True, filter_flip=True,
filter_dilation=(1, 1, 1)): filter_dilation=(1, 1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its weights """Compute conv output gradient w.r.t its weights
This function will build the symbolic graph for getting the This function will build the symbolic graph for getting the
...@@ -1293,7 +1296,8 @@ def conv3d_grad_wrt_weights(input, ...@@ -1293,7 +1296,8 @@ def conv3d_grad_wrt_weights(input,
border_mode=border_mode, border_mode=border_mode,
subsample=subsample, subsample=subsample,
filter_flip=filter_flip, filter_flip=filter_flip,
filter_dilation=filter_dilation) filter_dilation=filter_dilation,
num_groups=num_groups)
return gradWeight_op(input, output_grad, filter_shape[-3:]) return gradWeight_op(input, output_grad, filter_shape[-3:])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论