提交 8e9f3ff6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Rename ReorderRowElements to PermuteRowElements, to better reflect the

implementation.
上级 d147486a
...@@ -2735,8 +2735,8 @@ class InversePermutation(Op): ...@@ -2735,8 +2735,8 @@ class InversePermutation(Op):
inverse_permutation = InversePermutation() inverse_permutation = InversePermutation()
class ReorderRowElements(Op): class PermuteRowElements(Op):
"""Reorder each row (inner-most dim) of a tensor wrt a permutation. """Permute the elements of each row (inner-most dim) of a tensor.
The permutation argument (y) will be broadcasted to fit x, then each The permutation argument (y) will be broadcasted to fit x, then each
row (vector) of x will be reordered according to the corresponding row row (vector) of x will be reordered according to the corresponding row
...@@ -2784,10 +2784,10 @@ class ReorderRowElements(Op): ...@@ -2784,10 +2784,10 @@ class ReorderRowElements(Op):
self._rec_perform(node, x, y, outs[0], curdim=0) self._rec_perform(node, x, y, outs[0], curdim=0)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
gx = reorder_row_elements(gz, inverse_permutation(y)) gx = permute_row_elements(gz, inverse_permutation(y))
return [gx, None] return [gx, None]
reorder_row_elements = ReorderRowElements() permute_row_elements = PermuteRowElements()
######################### #########################
......
...@@ -6,7 +6,7 @@ import numpy ...@@ -6,7 +6,7 @@ import numpy
from theano.compile import module, In, Component from theano.compile import module, In, Component
from theano.gof import Container from theano.gof import Container
from theano.tensor import raw_random, reorder_row_elements from theano.tensor import raw_random, permute_row_elements
class RandomStreamsInstance(object): class RandomStreamsInstance(object):
"""RandomStreamsInstance""" """RandomStreamsInstance"""
...@@ -192,7 +192,7 @@ class RandomStreams(Component): ...@@ -192,7 +192,7 @@ class RandomStreams(Component):
def shuffle_row_elements(self, input): def shuffle_row_elements(self, input):
"""Return a variable with every row (rightmost index) shuffled""" """Return a variable with every row (rightmost index) shuffled"""
perm = self.permutation(input.ndim-1, input.shape[:-1], input.shape[-1]) perm = self.permutation(input.ndim-1, input.shape[:-1], input.shape[-1])
shuffled = reorder_row_elements(input, perm) shuffled = permute_row_elements(input, perm)
return shuffled return shuffled
...@@ -1832,65 +1832,65 @@ class TestInversePermutation(unittest.TestCase): ...@@ -1832,65 +1832,65 @@ class TestInversePermutation(unittest.TestCase):
assert numpy.all(i_row[p_row] == numpy.arange(10)) assert numpy.all(i_row[p_row] == numpy.arange(10))
class TestReorderRowElements(unittest.TestCase): class TestPermuteRowElements(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
def test_1_1(self): def test_1_1(self):
"""Test ReorderRowElements(vector, vector)""" """Test PermuteRowElements(vector, vector)"""
input = vector() input = vector()
p = ivector() p = ivector()
out = reorder_row_elements(input, p) out = permute_row_elements(input, p)
reorder = function([input, p], out) permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(5,)) input_val = rng.uniform(size=(5,))
p_val = rng.permutation(5) p_val = rng.permutation(5)
out_val = reorder(input_val, p_val) out_val = permute(input_val, p_val)
# Should be equivalent to advanced indexing # Should be equivalent to advanced indexing
out_bis = input_val[p_val] out_bis = input_val[p_val]
assert numpy.all(out_val == out_bis) assert numpy.all(out_val == out_bis)
# Verify gradient # Verify gradient
def reorder_fixed(s_input): def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val""" """Auxiliary op defined to get rid of gradient wrt p_val"""
return reorder_row_elements(s_input, p_val) return permute_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val]) utt.verify_grad(permute_fixed, [input_val])
def test_2_1(self): def test_2_1(self):
"""Test broadcasting in ReorderRowElements(matrix, vector)""" """Test broadcasting in PermuteRowElements(matrix, vector)"""
input = matrix() input = matrix()
p = ivector() p = ivector()
out = reorder_row_elements(input, p) out = permute_row_elements(input, p)
reorder = function([input, p], out) permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(3,5)) input_val = rng.uniform(size=(3,5))
p_val = rng.permutation(5) p_val = rng.permutation(5)
out_val = reorder(input_val, p_val) out_val = permute(input_val, p_val)
# The same permutation should be applied to every row of the input matrix. # The same permutation should be applied to every row of the input matrix.
out_bis = numpy.asarray([row[p_val] for row in input_val]) out_bis = numpy.asarray([row[p_val] for row in input_val])
assert numpy.all(out_val == out_bis) assert numpy.all(out_val == out_bis)
# Verify gradient # Verify gradient
def reorder_fixed(s_input): def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val""" """Auxiliary op defined to get rid of gradient wrt p_val"""
return reorder_row_elements(s_input, p_val) return permute_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val]) utt.verify_grad(permute_fixed, [input_val])
def test_2_2(self): def test_2_2(self):
"""Test ReorderRowElements(matrix, matrix)""" """Test PermuteRowElements(matrix, matrix)"""
input = matrix() input = matrix()
p = imatrix() p = imatrix()
out = reorder_row_elements(input, p) out = permute_row_elements(input, p)
reorder = function([input, p], out) permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(3,5)) input_val = rng.uniform(size=(3,5))
p_val = numpy.asarray([rng.permutation(5) for i in range(3)]) p_val = numpy.asarray([rng.permutation(5) for i in range(3)])
out_val = reorder(input_val, p_val) out_val = permute(input_val, p_val)
# Each row of p contains a permutation to apply to the corresponding # Each row of p contains a permutation to apply to the corresponding
# row of input # row of input
...@@ -1898,10 +1898,10 @@ class TestReorderRowElements(unittest.TestCase): ...@@ -1898,10 +1898,10 @@ class TestReorderRowElements(unittest.TestCase):
assert numpy.all(out_val == out_bis) assert numpy.all(out_val == out_bis)
# Verify gradient # Verify gradient
def reorder_fixed(s_input): def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val""" """Auxiliary op defined to get rid of gradient wrt p_val"""
return reorder_row_elements(s_input, p_val) return permute_row_elements(s_input, p_val)
utt.verify_grad(reorder_fixed, [input_val]) utt.verify_grad(permute_fixed, [input_val])
class test_tensordot(unittest.TestCase): class test_tensordot(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论