ConvOp: add infer_shape, handles case with imshp and kshp known.

上级 ecde9888
......@@ -483,6 +483,23 @@ class ConvOp(Op):
return gof.Apply(self, [_inputs, _kerns], [output])
def infer_shape(self, node, input_shapes):
imshp = input_shapes[0]
kshp = input_shapes[1]
batch_size = imshp[0]
fmo = kshp[0]
if self.imshp is not None and self.kshp is not None:
fmshp = ConvOp.getOutputShape(self.imshp[1:], self.kshp, (self.dx,self.dy), self.out_mode)
outshp = (batch_size,fmo) + tuple(fmshp)
return [outshp]
else:
# Haven't implemented this case. imshp and kshp may be symbollic
# and ConvOp.getOutputShape doesn't handle this. In this case
# we simply let the default function do its work.
return node.env.shape_feature.default_infer_shape(node, ishapes)
def perform(self,node, (img2d, filtersflipped), (z,)):
"""
By default if len(img2d.shape)==3, we
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论