提交 b92db391 authored 作者: affanv14's avatar affanv14

add numgroups to conv3d grad functions

上级 c776e6fa
......@@ -912,7 +912,8 @@ def conv3d_grad_wrt_inputs(output_grad,
border_mode='valid',
subsample=(1, 1, 1),
filter_flip=True,
filter_dilation=(1, 1, 1)):
filter_dilation=(1, 1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its inputs
This function builds the symbolic graph for getting the
......@@ -1035,7 +1036,8 @@ def conv3d_grad_wrt_inputs(output_grad,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
return grad_input_op(filters, output_grad, input_shape[-3:])
......@@ -1179,7 +1181,8 @@ def conv3d_grad_wrt_weights(input,
border_mode='valid',
subsample=(1, 1, 1),
filter_flip=True,
filter_dilation=(1, 1, 1)):
filter_dilation=(1, 1, 1),
num_groups=1):
"""Compute conv output gradient w.r.t its weights
This function will build the symbolic graph for getting the
......@@ -1293,7 +1296,8 @@ def conv3d_grad_wrt_weights(input,
border_mode=border_mode,
subsample=subsample,
filter_flip=filter_flip,
filter_dilation=filter_dilation)
filter_dilation=filter_dilation,
num_groups=num_groups)
return gradWeight_op(input, output_grad, filter_shape[-3:])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论