提交 465ced65 authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

modify infer_shape in corrmm

上级 79f24d78
......@@ -607,6 +607,7 @@ class CorrMM_gradWeights(BaseCorrMM):
imshp = input_shape[0]
topshp = input_shape[1]
ssize, imshp = imshp[1], list(imshp[2:])
ssize = ssize // self.num_groups
nkern, topshp = topshp[1], list(topshp[2:])
height_width = node.inputs[-2:]
if ((dH != 1) or (padH == -1)):
......@@ -707,6 +708,7 @@ class CorrMM_gradInputs(BaseCorrMM):
kshp = input_shape[0]
topshp = input_shape[1]
ssize, kshp = kshp[1], list(kshp[2:])
ssize = ssize * self.num_groups
bsize, topshp = topshp[0], list(topshp[2:])
height_width = node.inputs[-2:]
if padH == -1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论