提交 84cc3313 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

In inverse_permutation, use same dtype as permutation. Test updated.

上级 7ae82b46
...@@ -3625,7 +3625,10 @@ def inverse_permutation(perm): ...@@ -3625,7 +3625,10 @@ def inverse_permutation(perm):
"""Computes the inverse of permutations. """Computes the inverse of permutations.
Each row of input should contain a permutation of the first integers. Each row of input should contain a permutation of the first integers.
""" """
return permute_row_elements(arange(perm.shape[-1]), perm, inverse=True) return permute_row_elements(
arange(perm.shape[-1], dtype=perm.dtype),
perm,
inverse=True)
######################### #########################
# Advanced indexing # Advanced indexing
......
...@@ -2759,6 +2759,7 @@ class TestInversePermutation(unittest.TestCase): ...@@ -2759,6 +2759,7 @@ class TestInversePermutation(unittest.TestCase):
"""Test the inversion of one permutation (int vector)""" """Test the inversion of one permutation (int vector)"""
p = ivector() p = ivector()
inv = inverse_permutation(p) inv = inverse_permutation(p)
assert inv.dtype == p.dtype
f_inverse = function([p], inv) f_inverse = function([p], inv)
# Generate a random permutation # Generate a random permutation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论