提交 e686253f authored 作者: Frederic's avatar Frederic

fix gh-883. Fix a crash introduced in the trunk for reshape([-1])

上级 3474a426
...@@ -5391,11 +5391,13 @@ class Reshape(Op): ...@@ -5391,11 +5391,13 @@ class Reshape(Op):
requ = list(requ.data) requ = list(requ.data)
requ_part = [ele for ele in requ if ele != -1] requ_part = [ele for ele in requ if ele != -1]
crit = len(requ) - len(requ_part) crit = len(requ) - len(requ_part)
if crit == 1: if crit == 1 and len(requ_part) > 0:
missing = numpy.prod(ishapes[0]) / numpy.prod(requ_part) missing = mul(*ishapes[0]) / mul(*requ_part)
for i, ele in enumerate(requ): for i, ele in enumerate(requ):
if ele == -1: if ele == -1:
requ[i] = missing requ[i] = missing
elif crit == 1: # we reshape to -1
requ = [mul(*ishapes[0])]
elif crit > 1: elif crit > 1:
raise ValueError('shape argument to Reshape.perform' raise ValueError('shape argument to Reshape.perform'
' must have at most one entry equal to -1') ' must have at most one entry equal to -1')
......
...@@ -6686,8 +6686,17 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6686,8 +6686,17 @@ class TestInferShape(utt.InferShapeTester):
# (non-constant) input shape # (non-constant) input shape
admat = dmatrix() admat = dmatrix()
aivec = ivector() aivec = ivector()
ndim = 2 ndim = 1
admat_val = rand(3, 4) admat_val = rand(3, 4)
self._compile_and_check([admat],
[Reshape(ndim)(admat, [12])],
[admat_val], Reshape)
self._compile_and_check([admat],
[Reshape(ndim)(admat, [-1])],
[admat_val], Reshape)
ndim = 2
self._compile_and_check([admat], self._compile_and_check([admat],
[Reshape(ndim)(admat, [4, 3])], [Reshape(ndim)(admat, [4, 3])],
[admat_val], Reshape) [admat_val], Reshape)
...@@ -6696,6 +6705,17 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6696,6 +6705,17 @@ class TestInferShape(utt.InferShapeTester):
[Reshape(ndim)(admat, [4, -1])], [Reshape(ndim)(admat, [4, -1])],
[admat_val], Reshape) [admat_val], Reshape)
self._compile_and_check([admat],
[Reshape(ndim)(admat, [3, -1])],
[admat_val], Reshape)
self._compile_and_check([admat],
[Reshape(ndim)(admat, [-1, 3])],
[admat_val], Reshape)
self._compile_and_check([admat],
[Reshape(ndim)(admat, [-1, 4])],
[admat_val], Reshape)
# enable when infer_shape is generalized: # enable when infer_shape is generalized:
# self._compile_and_check([admat, aivec], # self._compile_and_check([admat, aivec],
# [Reshape(ndim)(admat, aivec)], # [Reshape(ndim)(admat, aivec)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论