提交 95450282 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

change make_node to support num_groups

上级 27c60ef3
......@@ -688,8 +688,12 @@ class CorrMM_gradInputs(BaseCorrMM):
height_width = [as_tensor_variable(shape[0]).astype('int64'),
as_tensor_variable(shape[1]).astype('int64')]
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False]
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0], False,
False, False]
else:
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False]
dtype = kern.type.dtype
return Apply(self, [kern, topgrad] + height_width,
[TensorType(dtype, broadcastable)()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论