提交 65b172b8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Unit tests for InversePermutation and ReorderRowElements Ops.

上级 e9ecfe71
......@@ -1784,12 +1784,104 @@ def test_flatten_outdim_invalid():
assert False
except ValueError:
pass
# TODO: write test case for Tile Op
def test_tile():
print >> sys.stderr, "WARNING: No testcase for Tile"
pass
class TestInversePermutation(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_dim1(self):
p = ivector()
inv = inverse_permutation(p)
f_inverse = function([p], inv)
rng = numpy.random.RandomState(utt.fetch_seed())
p_val = rng.permutation(10)
inv_val = f_inverse(p_val)
assert numpy.all(f_inverse(inv_val) == p_val)
assert numpy.all(p_val[inv_val] == numpy.arange(10))
assert numpy.all(inv_val[p_val] == numpy.arange(10))
def test_dim2(self):
p = imatrix()
inv = inverse_permutation(p)
f_inverse = function([p], inv)
rng = numpy.random.RandomState(utt.fetch_seed())
p_val = numpy.asarray([rng.permutation(10) for i in range(7)])
inv_val = f_inverse(p_val)
assert numpy.all(f_inverse(inv_val) == p_val)
for p_row, i_row in zip(p_val, inv_val):
assert numpy.all(p_row[i_row] == numpy.arange(10))
assert numpy.all(i_row[p_row] == numpy.arange(10))
class TestReorderRowElements(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_1_1(self):
input = vector()
p = ivector()
out = reorder_row_elements(input, p)
reorder = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(5,))
p_val = rng.permutation(5)
out_val = reorder(input_val, p_val)
out_bis = input_val[p_val]
assert numpy.all(out_val == out_bis)
# Verify gradient
def reorder_fixed(s_input):
return reorder_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val])
def test_2_1(self):
input = matrix()
p = ivector()
out = reorder_row_elements(input, p)
reorder = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(3,5))
p_val = rng.permutation(5)
out_val = reorder(input_val, p_val)
out_bis = numpy.asarray([row[p_val] for row in input_val])
assert numpy.all(out_val == out_bis)
# Verify gradient
def reorder_fixed(s_input):
return reorder_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val])
def test_2_2(self):
input = matrix()
p = imatrix()
out = reorder_row_elements(input, p)
reorder = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(3,5))
p_val = numpy.asarray([rng.permutation(5) for i in range(3)])
out_val = reorder(input_val, p_val)
out_bis = numpy.asarray([i_row[p_row] for i_row, p_row in zip(input_val, p_val)])
assert numpy.all(out_val == out_bis)
# Verify gradient
def reorder_fixed(s_input):
return reorder_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val])
class test_tensordot(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论