提交 a0400630 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

PermuteRowElements now supports broadcasting of all inputs.

Added tests to check for that (gradient included). New definition of inverse_permutation, using the PermuteRowElements and ARange Ops instead of the InversePermutation Op.
上级 0b20701b
...@@ -2784,7 +2784,9 @@ class InversePermutation(Op): ...@@ -2784,7 +2784,9 @@ class InversePermutation(Op):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return [None] return [None]
inverse_permutation = InversePermutation() #inverse_permutation = InversePermutation()
def inverse_permutation(perm):
    """Return the inverse of each permutation stored along the last axis of `perm`.

    Builds the identity permutation [0, 1, ..., n-1] (with n the length of
    the inner-most dimension of `perm`) and asks PermuteRowElements to apply
    the inverse of `perm` to it, which yields the inverse permutation(s).
    """
    identity = arange(perm.shape[-1])
    return permute_row_elements(identity, perm, inverse=True)
class PermuteRowElements(Op): class PermuteRowElements(Op):
"""Permute the elements of each row (inner-most dim) of a tensor. """Permute the elements of each row (inner-most dim) of a tensor.
...@@ -2801,7 +2803,11 @@ class PermuteRowElements(Op): ...@@ -2801,7 +2803,11 @@ class PermuteRowElements(Op):
If x.ndim > y.ndim, y will be broadcasted to fit x, then each row (vector) If x.ndim > y.ndim, y will be broadcasted to fit x, then each row (vector)
of x will be reordered according to the corresponding row of y. (This is of x will be reordered according to the corresponding row of y. (This is
a generalization of the first case). a generalization of the first case).
WARNING: x will not be broadcasted to fit y (not implemented yet). If x.ndim = 1, every permutation in y will be applied to x, and the output
will contain all the results.
If x.ndim < y.ndim, x will be broadcasted to fit y, and different
permutations contained in y will be applied to each vector in x. (This is
a generalization of the previous case).
If the "inverse" argument is True, the Op will perform the inverse If the "inverse" argument is True, the Op will perform the inverse
permutation instead. permutation instead.
...@@ -2819,12 +2825,21 @@ class PermuteRowElements(Op): ...@@ -2819,12 +2825,21 @@ class PermuteRowElements(Op):
(inverse.type.dtype.startswith('int') or\ (inverse.type.dtype.startswith('int') or\
inverse.type.dtype.startswith('uint')) inverse.type.dtype.startswith('uint'))
# extend y dimension to match x # Match shapes of x and y
assert x.type.ndim >= y.type.ndim x_dim = x.type.ndim
y = shape_padleft(y, n_ones=(x.type.ndim - y.type.ndim)) y_dim = y.type.ndim
if x_dim > y_dim:
y = shape_padleft(y, n_ones=(x_dim - y_dim))
elif x_dim < y_dim:
x = shape_padleft(x, n_ones=(y_dim - x_dim))
# Compute the broadcastable pattern of the output
out_broadcastable = [xb and yb for xb, yb in zip(x.type.broadcastable, y.type.broadcastable)]
out_type = tensor(dtype = x.type.dtype, broadcastable = out_broadcastable)
inputlist = [x, y, inverse] inputlist = [x, y, inverse]
outputlist = [x.type()] outputlist = [out_type]
return Apply(self, inputlist, outputlist) return Apply(self, inputlist, outputlist)
def _rec_perform(self, node, x, y, inverse, out, curdim): def _rec_perform(self, node, x, y, inverse, out, curdim):
...@@ -2855,10 +2870,14 @@ class PermuteRowElements(Op): ...@@ -2855,10 +2870,14 @@ class PermuteRowElements(Op):
if xs0 == ys0: if xs0 == ys0:
for i in range(xs0): for i in range(xs0):
self._rec_perform(node, x[i], y[i], inverse, out[i], curdim+1) self._rec_perform(node, x[i], y[i], inverse, out[i], curdim+1)
elif node.inputs[1].type.broadcastable[curdim]: elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
# Broadcast y # Broadcast y
for i in range(xs0): for i in range(xs0):
self._rec_perform(node, x[i], y[0], inverse, out[i], curdim+1) self._rec_perform(node, x[i], y[0], inverse, out[i], curdim+1)
elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
# Broadcast x
for i in range(ys0):
self._rec_perform(node, x[0], y[i], inverse, out[i], curdim+1)
else: else:
raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0)) raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0))
...@@ -2867,15 +2886,49 @@ class PermuteRowElements(Op): ...@@ -2867,15 +2886,49 @@ class PermuteRowElements(Op):
y_s = y.shape y_s = y.shape
assert len(x_s) == len(y_s) assert len(x_s) == len(y_s)
if outs[0] is None or outs[0].shape != x_s: # Make sure the output is big enough
outs[0] = numpy.empty_like(x) out_s = []
for xdim, ydim in zip(x_s, y_s):
if xdim == ydim:
outdim = xdim
elif xdim == 1:
outdim = ydim
elif ydim == 1:
outdim = xdim
else:
raise ValueError('Dimension mismatch: %s, %s' % (xdim, ydim))
out_s.append(outdim)
if outs[0] is None or outs[0].shape != out_s:
outs[0] = numpy.empty(out_s, dtype=x.dtype)
self._rec_perform(node, x, y, inverse, outs[0], curdim=0) self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
def grad(self, (x, y, inverse), (gz,)): def grad(self, (x, y, inverse), (gz,)):
"""If 'inverse' is False (0), apply the inverse of y on gz. # First, compute the gradient wrt the broadcasted x.
Else, apply y on gz.""" # If 'inverse' is False (0), apply the inverse of y on gz.
# Else, apply y on gz.
gx = permute_row_elements(gz, y, eq(inverse, 0)) gx = permute_row_elements(gz, y, eq(inverse, 0))
# If x has been broadcasted along some axes, we need to sum
# the gradient over these axes, but keep the dimension (as
# broadcastable)
broadcasted_dims = [dim for dim in range(gz.type.ndim)\
if x.type.broadcastable[dim] and not gz.type.broadcastable[dim]]
gx = Sum(axis = broadcasted_dims)(gx)
# Sum(...) removed the dimensions in broadcasted_dims,
# so we need to put them back.
newdims = []
i = 0
for dim in range(gz.type.ndim):
if dim in broadcasted_dims:
newdims.append('x')
else:
newdims.append(i)
i += 1
gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
return [gx, None, None] return [gx, None, None]
_permute_row_elements = PermuteRowElements() _permute_row_elements = PermuteRowElements()
......
...@@ -1939,7 +1939,7 @@ class TestInversePermutation(unittest.TestCase): ...@@ -1939,7 +1939,7 @@ class TestInversePermutation(unittest.TestCase):
assert numpy.all(inv_val[p_val] == numpy.arange(10)) assert numpy.all(inv_val[p_val] == numpy.arange(10))
def test_dim2(self): def test_dim2(self):
"""Test the inversion of several permutation at a time""" """Test the inversion of several permutations at a time"""
# Each row of p is a different permutation to inverse # Each row of p is a different permutation to inverse
p = imatrix() p = imatrix()
inv = inverse_permutation(p) inv = inverse_permutation(p)
...@@ -2030,6 +2030,56 @@ class TestPermuteRowElements(unittest.TestCase): ...@@ -2030,6 +2030,56 @@ class TestPermuteRowElements(unittest.TestCase):
return permute_row_elements(s_input, p_val) return permute_row_elements(s_input, p_val)
utt.verify_grad(permute_fixed, [input_val]) utt.verify_grad(permute_fixed, [input_val])
def test_1_2(self):
    """Test PermuteRowElements(vector, matrix).

    Each row of the permutation matrix is a different permutation,
    and all of them are applied to the same input vector.
    """
    inp = vector()
    perm = imatrix()
    permuted = permute_row_elements(inp, perm)
    f = function([inp, perm], permuted)

    rng = numpy.random.RandomState(utt.fetch_seed())
    inp_val = rng.uniform(size=(5,))
    perm_val = numpy.asarray([rng.permutation(5) for i in range(3)])
    out_val = f(inp_val, perm_val)

    # Reference computation: apply each row of perm_val to the input vector.
    expected = numpy.asarray([inp_val[row] for row in perm_val])
    assert numpy.all(out_val == expected)

    # Verify gradient
    def permute_fixed(s_input):
        """Auxiliary op defined to get rid of gradient wrt p_val"""
        return permute_row_elements(s_input, perm_val)
    utt.verify_grad(permute_fixed, [inp_val])
def test_3b_2(self):
    """Test permute_row_elements on a more complex broadcasting pattern:
    input.type.broadcastable = (False, True, False),
    p.type.broadcastable = (False, False)."""
    inp = TensorType('float64', (False, True, False))()
    perm = imatrix()
    permuted = permute_row_elements(inp, perm)
    f = function([inp, perm], permuted)

    rng = numpy.random.RandomState(utt.fetch_seed())
    inp_val = rng.uniform(size=(4, 1, 5))
    perm_val = numpy.asarray([rng.permutation(5) for i in range(3)])
    out_val = f(inp_val, perm_val)

    # Reference computation: the broadcastable middle dimension of the
    # input (always index 0) is permuted by every row of perm_val.
    expected = numpy.asarray(
        [[mat[0, row] for row in perm_val] for mat in inp_val])
    assert numpy.all(out_val == expected)

    # Verify gradient
    def permute_fixed(s_input):
        """Auxiliary op defined to get rid of gradient wrt p_val"""
        return permute_row_elements(s_input, perm_val)
    utt.verify_grad(permute_fixed, [inp_val])
class test_tensordot(unittest.TestCase): class test_tensordot(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论