提交 f8955c9e authored 作者: affanv14's avatar affanv14 提交者: Mohammed Affan

add condition in infer_shape

上级 3c67d37e
......@@ -1945,8 +1945,12 @@ class AbstractConv_gradWeights(BaseAbstractConv):
imshp = input_shapes[0]
topshp = input_shapes[1]
kshp = self.kshp[:] if self.kshp is not None else [None] * (2 + self.convdim)
fallback_kshp = ([topshp[1], imshp[1] // self.num_groups] +
[node.inputs[2][i] for i in range(self.convdim)])
if self.num_groups > 1:
fallback_kshp = ([topshp[1], imshp[1] // self.num_groups] +
[node.inputs[2][i] for i in range(self.convdim)])
else:
fallback_kshp = ([topshp[1], imshp[1]] +
[node.inputs[2][i] for i in range(self.convdim)])
kshp = [fallback_kshp[i] if kshp[i] is None else kshp[i]
for i in range(2 + self.convdim)]
return [kshp]
......@@ -2207,8 +2211,12 @@ class AbstractConv_gradInputs(BaseAbstractConv):
kshp = input_shapes[0]
topshp = input_shapes[1]
imshp = self.imshp[:] if self.imshp is not None else [None] * (2 + self.convdim)
fallback_imshp = ([topshp[0], kshp[1] * self.num_groups] +
[node.inputs[2][i] for i in range(self.convdim)])
if self.num_groups > 1:
fallback_imshp = ([topshp[0], kshp[1] * self.num_groups] +
[node.inputs[2][i] for i in range(self.convdim)])
else:
fallback_imshp = ([topshp[0], kshp[1]] +
[node.inputs[2][i] for i in range(self.convdim)])
imshp = [fallback_imshp[i] if imshp[i] is None else imshp[i]
for i in range(2 + self.convdim)]
return [imshp]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论