提交 40407ca5 authored 作者: affanv14's avatar affanv14

change infer_shape in corr3dmm for grouped convolutions

上级 de901ea6
......@@ -662,6 +662,7 @@ class Corr3dMM_gradWeights(BaseCorr3dMM):
imshp = input_shape[0]
topshp = input_shape[1]
ssize, imshp = imshp[1], list(imshp[2:])
ssize = ssize // self.num_groups
nkern, topshp = topshp[1], list(topshp[2:])
height_width_depth = node.inputs[-3:]
if ((dH != 1) or (padH == -1)):
......@@ -773,6 +774,7 @@ class Corr3dMM_gradInputs(BaseCorr3dMM):
kshp = input_shape[0]
topshp = input_shape[1]
ssize, kshp = kshp[1], list(kshp[2:])
ssize = ssize * self.num_groups
bsize, topshp = topshp[0], list(topshp[2:])
height_width_depth = node.inputs[-3:]
if padH == -1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论