提交 f84b04d4 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op Flatten

上级 45c53b93
......@@ -5426,6 +5426,12 @@ class Flatten(Op):
(numpy.prod(x.shape[outdim - 1:]),))
out[0] = x.reshape(newshape)
def infer_shape(self, node, in_shapes):
in_shp, = in_shapes
out_shape = (in_shp[:self.outdim - 1] +
(numpy.prod(in_shp[self.outdim - 1:]),))
return [out_shape]
def grad(self, inp, grads):
x, = inp
g_out, = grads
......
......@@ -6047,6 +6047,19 @@ class TestInferShape(utt.InferShapeTester):
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
# Flatten
adtens = tensor3()
adtens_val = rand(4, 5, 3)
outdim = 2
self._compile_and_check([adtens],
[Flatten(outdim)(adtens)],
[adtens_val], Flatten)
outdim = 1
self._compile_and_check([adtens],
[Flatten(outdim)(adtens)],
[adtens_val], Flatten)
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论