提交 64dfa93e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Eager optimization for no-op flatten

上级 0195a930
...@@ -3081,6 +3081,10 @@ def flatten(x, ndim=1): ...@@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
else: else:
dims = (-1,) dims = (-1,)
if len(dims) == _x.ndim:
# Nothing to ravel
return _x
x_reshaped = _x.reshape(dims) x_reshaped = _x.reshape(dims)
shape_kept_dims = _x.type.shape[: ndim - 1] shape_kept_dims = _x.type.shape[: ndim - 1]
bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :]) bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :])
......
...@@ -3867,35 +3867,22 @@ class TestInferShape(utt.InferShapeTester): ...@@ -3867,35 +3867,22 @@ class TestInferShape(utt.InferShapeTester):
def test_Flatten(self): def test_Flatten(self):
atens3 = tensor3() atens3 = tensor3()
atens3_val = random(4, 5, 3) atens3_val = random(4, 5, 3)
for ndim in (3, 2, 1): for ndim in (2, 1):
self._compile_and_check( self._compile_and_check(
[atens3], [atens3],
[flatten(atens3, ndim)], [flatten(atens3, ndim)],
[atens3_val], [atens3_val],
Reshape, Reshape,
excluding=["local_useless_reshape"],
) )
amat = matrix() amat = matrix()
amat_val = random(4, 5) amat_val = random(4, 5)
for ndim in (2, 1):
self._compile_and_check(
[amat],
[flatten(amat, ndim)],
[amat_val],
Reshape,
excluding=["local_useless_reshape"],
)
avec = vector()
avec_val = random(4)
ndim = 1 ndim = 1
self._compile_and_check( self._compile_and_check(
[avec], [amat],
[flatten(avec, ndim)], [flatten(amat, ndim)],
[avec_val], [amat_val],
Reshape, Reshape,
excluding=["local_useless_reshape"],
) )
def test_Eye(self): def test_Eye(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论