提交 02545ed5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Specify reshape shape length if unknown

上级 141307f0
......@@ -644,6 +644,8 @@ class Reshape(COp):
x = ptb.as_tensor_variable(x)
shp_orig = shp
shp = ptb.as_tensor_variable(shp, ndim=1)
if shp.type.shape == (None,):
shp = specify_shape(shp, self.ndim)
if not (
shp.dtype in int_dtypes
or (isinstance(shp, TensorConstant) and shp.data.size == 0)
......
......@@ -98,6 +98,7 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
Shape_i,
DimShuffle,
Elemwise,
SpecifyShape,
)
super().setup_method()
......@@ -253,9 +254,7 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
f(a_val, [7, 5])
with pytest.raises(ValueError):
f(a_val, [-1, -1])
with pytest.raises(
ValueError, match=".*Shape argument to Reshape has incorrect length.*"
):
with pytest.raises(AssertionError):
f(a_val, [3, 4, 1])
def test_0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论