提交 141307f0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in local_useless_reshape

上级 947b9409
...@@ -897,7 +897,8 @@ def local_useless_reshape(fgraph, node): ...@@ -897,7 +897,8 @@ def local_useless_reshape(fgraph, node):
if nb_m1 <= 1 and all(shape_match): if nb_m1 <= 1 and all(shape_match):
return [inp] return [inp]
if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1): # There is one missing match, but all other dimensions match
if (nb_m1 == 0) and (shape_match.count(False) == 1):
return [inp] return [inp]
return False return False
......
...@@ -383,6 +383,13 @@ class TestLocalUselessReshape: ...@@ -383,6 +383,13 @@ class TestLocalUselessReshape:
new_out = rewrite_graph(out) new_out = rewrite_graph(out)
assert new_out is out assert new_out is out
# Or if more than one dimension cannot be matched
x = tensor(shape=(None, None, None))
shape = [x.shape[0], 3, 3]
out = reshape(x, shape)
new_out = rewrite_graph(out)
assert new_out is out
class TestLocalReshapeToDimshuffle: class TestLocalReshapeToDimshuffle:
def setup_method(self): def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论