提交 84cc3313 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

In inverse_permutation, use same dtype as permutation. Test updated.

上级 7ae82b46
......@@ -3625,7 +3625,10 @@ def inverse_permutation(perm):
"""Computes the inverse of permutations.
Each row of input should contain a permutation of the first integers.
"""
return permute_row_elements(arange(perm.shape[-1]), perm, inverse=True)
return permute_row_elements(
arange(perm.shape[-1], dtype=perm.dtype),
perm,
inverse=True)
#########################
# Advanced indexing
......
......@@ -2759,6 +2759,7 @@ class TestInversePermutation(unittest.TestCase):
"""Test the inversion of one permutation (int vector)"""
p = ivector()
inv = inverse_permutation(p)
assert inv.dtype == p.dtype
f_inverse = function([p], inv)
# Generate a random permutation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论