提交 95450282 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

change make_node to support num_groups

上级 27c60ef3
...@@ -688,6 +688,10 @@ class CorrMM_gradInputs(BaseCorrMM): ...@@ -688,6 +688,10 @@ class CorrMM_gradInputs(BaseCorrMM):
height_width = [as_tensor_variable(shape[0]).astype('int64'), height_width = [as_tensor_variable(shape[0]).astype('int64'),
as_tensor_variable(shape[1]).astype('int64')] as_tensor_variable(shape[1]).astype('int64')]
if self.num_groups > 1:
broadcastable = [topgrad.type.broadcastable[0], False,
False, False]
else:
broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1], broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
False, False] False, False]
dtype = kern.type.dtype dtype = kern.type.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论