提交 95dc982b authored 作者: Frederic's avatar Frederic

fix corner case crash in new Flatten.infer_shape.

上级 db6f4cf2
......@@ -5480,8 +5480,15 @@ class Flatten(Op):
def infer_shape(self, node, in_shapes):
in_shp, = in_shapes
out_shape = (in_shp[:self.outdim - 1] +
(numpy.prod(in_shp[self.outdim - 1:]),))
part1 = in_shp[:self.outdim - 1]
part2 = in_shp[self.outdim - 1:]
# The if is needed as numpy.prod([]) return a float 1.0
# We do not want to force the other dtype to int32/64.
if len(part2) > 1:
part2 = prod(part2, dtype='int64')
else:
part2 = 1
out_shape = (part1 + (part2,))
return [out_shape]
def grad(self, inp, grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论