提交 28d9d4dc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Improve static output shape of Reshape

上级 734009ae
......@@ -669,7 +669,7 @@ class Reshape(COp):
assert shp.ndim == 1
if isinstance(shp, TensorConstant):
out_shape = tuple(int(s) if s >= 0 else None for s in shp.data)
out_shape = [int(s) if s >= 0 else None for s in shp.data]
else:
out_shape = [None] * self.ndim
shp_list = shp_orig
......@@ -685,6 +685,29 @@ class Reshape(COp):
except NotScalarConstantError:
pass
# If we only don't know the size of one output dimension,
# but we know all the input dimensions we can deduce it
# This happens often when there is -1 as an input of Reshape
if None not in x.type.shape and out_shape.count(None) == 1:
full_size = np.prod(x.type.shape)
known_size = np.prod([s for s in out_shape if s is not None])
out_shape[out_shape.index(None)] = int(full_size // known_size)
out_shape = tuple(out_shape)
# Run some eager error checks
if len(out_shape) != self.ndim:
raise ValueError(
"Shape argument to Reshape has incorrect length:"
f" {len(out_shape)}, should be {self.ndim}"
)
if None not in x.type.shape and None not in out_shape:
if np.prod(x.type.shape) != np.prod(out_shape):
raise ValueError(
f"Reshape: Input shape {x.type.shape} is incompatible with new shape {out_shape}"
)
return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
def perform(self, node, inp, out_):
......
import re
import numpy as np
import pytest
......@@ -353,6 +355,29 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
assert tuple(y_new.shape.eval({i: i_test})) == (4, 25)
assert y_new.eval({i: i_test}).shape == (4, 25)
def test_static_shape(self):
dim = lscalar("dim")
x1 = tensor(shape=(2, 2, None))
x2 = specify_shape(x1, (2, 2, 6))
assert reshape(x1, (6, 2)).type.shape == (6, 2)
assert reshape(x1, (6, -1)).type.shape == (6, None)
assert reshape(x1, (6, dim)).type.shape == (6, None)
assert reshape(x1, (6, dim, 2)).type.shape == (6, None, 2)
assert reshape(x1, (6, 3, 99)).type.shape == (6, 3, 99)
assert reshape(x2, (6, 4)).type.shape == (6, 4)
assert reshape(x2, (6, -1)).type.shape == (6, 4)
assert reshape(x2, (6, dim)).type.shape == (6, 4)
assert reshape(x2, (6, dim, 2)).type.shape == (6, 2, 2)
with pytest.raises(
ValueError,
match=re.escape(
"Reshape: Input shape (2, 2, 6) is incompatible with new shape (6, 3, 99)"
),
):
reshape(x2, (6, 3, 99))
def test_shape_i_hash():
assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论