提交 02545ed5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Specify reshape shape length if unknown

上级 141307f0
...@@ -644,6 +644,8 @@ class Reshape(COp): ...@@ -644,6 +644,8 @@ class Reshape(COp):
x = ptb.as_tensor_variable(x) x = ptb.as_tensor_variable(x)
shp_orig = shp shp_orig = shp
shp = ptb.as_tensor_variable(shp, ndim=1) shp = ptb.as_tensor_variable(shp, ndim=1)
if shp.type.shape == (None,):
shp = specify_shape(shp, self.ndim)
if not ( if not (
shp.dtype in int_dtypes shp.dtype in int_dtypes
or (isinstance(shp, TensorConstant) and shp.data.size == 0) or (isinstance(shp, TensorConstant) and shp.data.size == 0)
......
...@@ -98,6 +98,7 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -98,6 +98,7 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
Shape_i, Shape_i,
DimShuffle, DimShuffle,
Elemwise, Elemwise,
SpecifyShape,
) )
super().setup_method() super().setup_method()
...@@ -253,9 +254,7 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -253,9 +254,7 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
f(a_val, [7, 5]) f(a_val, [7, 5])
with pytest.raises(ValueError): with pytest.raises(ValueError):
f(a_val, [-1, -1]) f(a_val, [-1, -1])
with pytest.raises( with pytest.raises(AssertionError):
ValueError, match=".*Shape argument to Reshape has incorrect length.*"
):
f(a_val, [3, 4, 1]) f(a_val, [3, 4, 1])
def test_0(self): def test_0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论