提交 5d568c17 authored 作者: nouiz's avatar nouiz

Merge pull request #1037 from lamblin/fix_flatten_infershape

Fix bug in infer_shape of Flatten
...@@ -5778,13 +5778,20 @@ class Flatten(Op): ...@@ -5778,13 +5778,20 @@ class Flatten(Op):
in_shp, = in_shapes in_shp, = in_shapes
part1 = in_shp[:self.outdim - 1] part1 = in_shp[:self.outdim - 1]
part2 = in_shp[self.outdim - 1:] part2 = in_shp[self.outdim - 1:]
# The if is needed as numpy.prod([]) return a float 1.0
# We do not want to force the other dtype to int32/64.
if len(part2) > 1: if len(part2) > 1:
part2 = prod(part2, dtype='int64') part2 = (prod(part2, dtype='int64'),)
elif len(part2) == 1:
# We do not want to force an upcast of part2 if its length is 1
pass
else: else:
part2 = 1 if len(in_shp) == 0 and self.outdim == 1:
out_shape = (part1 + (part2,)) part2 = (1,)
else:
raise ValueError('invalid output ndimensions (%i) for tensor '
'of rank %i' % (self.outdim, len(in_shp)))
out_shape = (part1 + part2)
return [out_shape] return [out_shape]
def grad(self, inp, grads): def grad(self, inp, grads):
......
...@@ -6505,17 +6505,26 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6505,17 +6505,26 @@ class TestInferShape(utt.InferShapeTester):
[admat_val, adtens4_val], TensorDot) [admat_val, adtens4_val], TensorDot)
# Flatten # Flatten
adtens = tensor3() atens3 = tensor3()
adtens_val = rand(4, 5, 3) atens3_val = rand(4, 5, 3)
outdim = 2 for outdim in (3, 2, 1):
self._compile_and_check([adtens], self._compile_and_check([atens3],
[Flatten(outdim)(adtens)], [Flatten(outdim)(atens3)],
[adtens_val], Flatten) [atens3_val], Flatten)
amat = matrix()
amat_val = rand(4, 5)
for outdim in (2, 1):
self._compile_and_check([amat],
[Flatten(outdim)(amat)],
[amat_val], Flatten)
avec = vector()
avec_val = rand(4)
outdim = 1 outdim = 1
self._compile_and_check([adtens], self._compile_and_check([avec],
[Flatten(outdim)(adtens)], [Flatten(outdim)(avec)],
[adtens_val], Flatten) [avec_val], Flatten)
# Eye # Eye
aiscal = iscalar() aiscal = iscalar()
...@@ -6536,6 +6545,8 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6536,6 +6545,8 @@ class TestInferShape(utt.InferShapeTester):
# Shape # Shape
# 'opt.Makevector' precludes optimizer from disentangling # 'opt.Makevector' precludes optimizer from disentangling
# elements of shape # elements of shape
adtens = tensor3()
adtens_val = rand(4, 5, 3)
self._compile_and_check([adtens], self._compile_and_check([adtens],
[Shape()(adtens)], [Shape()(adtens)],
[adtens_val], (opt.MakeVector, Shape)) [adtens_val], (opt.MakeVector, Shape))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论