提交 5d568c17 · 作者：nouiz

Merge pull request #1037 from lamblin/fix_flatten_infershape

Fix bug in infer_shape of Flatten
......@@ -5778,13 +5778,20 @@ class Flatten(Op):
in_shp, = in_shapes
part1 = in_shp[:self.outdim - 1]
part2 = in_shp[self.outdim - 1:]
# The if is needed as numpy.prod([]) return a float 1.0
# We do not want to force the other dtype to int32/64.
if len(part2) > 1:
part2 = prod(part2, dtype='int64')
part2 = (prod(part2, dtype='int64'),)
elif len(part2) == 1:
# We do not want to force an upcast of part2 if its length is 1
pass
else:
part2 = 1
out_shape = (part1 + (part2,))
if len(in_shp) == 0 and self.outdim == 1:
part2 = (1,)
else:
raise ValueError('invalid output ndimensions (%i) for tensor '
'of rank %i' % (self.outdim, len(in_shp)))
out_shape = (part1 + part2)
return [out_shape]
def grad(self, inp, grads):
......
......@@ -6505,17 +6505,26 @@ class TestInferShape(utt.InferShapeTester):
[admat_val, adtens4_val], TensorDot)
# Flatten
adtens = tensor3()
adtens_val = rand(4, 5, 3)
outdim = 2
self._compile_and_check([adtens],
[Flatten(outdim)(adtens)],
[adtens_val], Flatten)
atens3 = tensor3()
atens3_val = rand(4, 5, 3)
for outdim in (3, 2, 1):
self._compile_and_check([atens3],
[Flatten(outdim)(atens3)],
[atens3_val], Flatten)
amat = matrix()
amat_val = rand(4, 5)
for outdim in (2, 1):
self._compile_and_check([amat],
[Flatten(outdim)(amat)],
[amat_val], Flatten)
avec = vector()
avec_val = rand(4)
outdim = 1
self._compile_and_check([adtens],
[Flatten(outdim)(adtens)],
[adtens_val], Flatten)
self._compile_and_check([avec],
[Flatten(outdim)(avec)],
[avec_val], Flatten)
# Eye
aiscal = iscalar()
......@@ -6536,6 +6545,8 @@ class TestInferShape(utt.InferShapeTester):
# Shape
# 'opt.Makevector' precludes optimizer from disentangling
# elements of shape
adtens = tensor3()
adtens_val = rand(4, 5, 3)
self._compile_and_check([adtens],
[Shape()(adtens)],
[adtens_val], (opt.MakeVector, Shape))
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论