提交 9343f65e authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3549 from lamblin/fix_join_infer_shape

Fix infer_shape of Join when axis < 0
......@@ -3907,6 +3907,14 @@ class Join(Op):
assert shp is not None
assert len(shp) == n_dim
# The joining dimension could be negative, but we need it to be
# in [0, n_dim) in the loop below.
# An axis < -n_dim or >= ndim would be invalid, but this is
# not checked here. An Assert op would be a way of addressing that,
# but it may disrupt optimizations.
join_dim = switch(ge(node.inputs[0], 0),
node.inputs[0],
node.inputs[0] + n_dim)
out_shapes = []
for dim in xrange(n_dim):
# we have to deal with 2 possible cases in here :
......@@ -3924,7 +3932,7 @@ class Join(Op):
for shp in ishapes[2:]:
t_side = t_side + shp[dim]
# return the dimensions found
out_shapes.append(switch(eq(dim, node.inputs[0]),
out_shapes.append(switch(eq(dim, join_dim),
t_side, f_side))
return [tuple(out_shapes)]
......
......@@ -7259,27 +7259,31 @@ class TestInferShape(utt.InferShapeTester):
aivec = ivector()
adtens_val = rand(4, 10, 3)
aivec_val = [2, 5, 3]
self._compile_and_check([adtens, aiscal, aivec],
[Split(3)(adtens, aiscal, aivec)[0]],
[adtens_val, 1, aivec_val], (Split))
for aiscal_val in [1, -2]:
self._compile_and_check(
[adtens, aiscal, aivec],
[Split(3)(adtens, aiscal, aivec)[0]],
[adtens_val, aiscal_val, aivec_val], (Split))
# Join
cdmat = dmatrix()
admat_val = rand(1, 3)
bdmat_val = rand(2, 3)
cdmat_val = rand(4, 3)
aiscal_val = 0
self._compile_and_check([aiscal, admat, bdmat, cdmat],
[Join()(aiscal, admat, bdmat, cdmat)],
[aiscal_val, admat_val, bdmat_val, cdmat_val], Join)
for aiscal_val in [0, -2]:
self._compile_and_check(
[aiscal, admat, bdmat, cdmat],
[Join()(aiscal, admat, bdmat, cdmat)],
[aiscal_val, admat_val, bdmat_val, cdmat_val], Join)
admat_val = rand(4, 1)
bdmat_val = rand(4, 3)
cdmat_val = rand(4, 2)
aiscal_val = 1
self._compile_and_check([aiscal, admat, bdmat, cdmat],
[Join()(aiscal, admat, bdmat, cdmat)],
[aiscal_val, admat_val, bdmat_val, cdmat_val], Join)
for aiscal_val in [-1, 1]:
self._compile_and_check(
[aiscal, admat, bdmat, cdmat],
[Join()(aiscal, admat, bdmat, cdmat)],
[aiscal_val, admat_val, bdmat_val, cdmat_val], Join)
# PermuteRowElements
abool = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论