提交 a0400630 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

PermuteRowElements now support broadcasting of all inputs.

Added test to check for that (gradient included). New definition of inverse_permutation, using PermuteRowElements and ARange Ops instead of InversePermutation Op.
上级 0b20701b
......@@ -2784,7 +2784,9 @@ class InversePermutation(Op):
def grad(self, (x,), (gz,)):
return [None]
inverse_permutation = InversePermutation()
#inverse_permutation = InversePermutation()
def inverse_permutation(perm):
return permute_row_elements(arange(perm.shape[-1]), perm, inverse=True)
class PermuteRowElements(Op):
"""Permute the elements of each row (inner-most dim) of a tensor.
......@@ -2801,7 +2803,11 @@ class PermuteRowElements(Op):
If x.ndim > y.ndim, y will be broadcasted to fit x, then each row (vector)
of x will be reordered according to the corresponding row of y. (This is
a generalization of the first case).
WARNING: x will not be broadcasted to fit y (not implemented yet).
If x.ndim = 1, every permutation in y will be applied to x, and the output
will contain all the results.
If x.ndim < y.ndim, x will be broadcasted to fit y, and different
permutations contained in y will be applied to each vector in x. (This is
a generalization of the previous case).
If the "inverse" argument is True, the Op will perform the inverse
permutation instead.
......@@ -2819,12 +2825,21 @@ class PermuteRowElements(Op):
(inverse.type.dtype.startswith('int') or\
inverse.type.dtype.startswith('uint'))
# extend y dimension to match x
assert x.type.ndim >= y.type.ndim
y = shape_padleft(y, n_ones=(x.type.ndim - y.type.ndim))
# Match shapes of x and y
x_dim = x.type.ndim
y_dim = y.type.ndim
if x_dim > y_dim:
y = shape_padleft(y, n_ones=(x_dim - y_dim))
elif x_dim < y_dim:
x = shape_padleft(x, n_ones=(y_dim - x_dim))
# Compute the broadcastable pattern of the output
out_broadcastable = [xb and yb for xb, yb in zip(x.type.broadcastable, y.type.broadcastable)]
out_type = tensor(dtype = x.type.dtype, broadcastable = out_broadcastable)
inputlist = [x, y, inverse]
outputlist = [x.type()]
outputlist = [out_type]
return Apply(self, inputlist, outputlist)
def _rec_perform(self, node, x, y, inverse, out, curdim):
......@@ -2855,10 +2870,14 @@ class PermuteRowElements(Op):
if xs0 == ys0:
for i in range(xs0):
self._rec_perform(node, x[i], y[i], inverse, out[i], curdim+1)
elif node.inputs[1].type.broadcastable[curdim]:
elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
# Broadcast y
for i in range(xs0):
self._rec_perform(node, x[i], y[0], inverse, out[i], curdim+1)
elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
# Broadcast x
for i in range(ys0):
self._rec_perform(node, x[0], y[i], inverse, out[i], curdim+1)
else:
raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0))
......@@ -2867,15 +2886,49 @@ class PermuteRowElements(Op):
y_s = y.shape
assert len(x_s) == len(y_s)
if outs[0] is None or outs[0].shape != x_s:
outs[0] = numpy.empty_like(x)
# Make sure the output is big enough
out_s = []
for xdim, ydim in zip(x_s, y_s):
if xdim == ydim:
outdim = xdim
elif xdim == 1:
outdim = ydim
elif ydim == 1:
outdim = xdim
else:
raise ValueError('Dimension mismatch: %s, %s' % (xdim, ydim))
out_s.append(outdim)
if outs[0] is None or outs[0].shape != out_s:
outs[0] = numpy.empty(out_s, dtype=x.dtype)
self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
def grad(self, (x, y, inverse), (gz,)):
"""If 'inverse' is False (0), apply the inverse of y on gz.
Else, apply y on gz."""
# First, compute the gradient wrt the broadcasted x.
# If 'inverse' is False (0), apply the inverse of y on gz.
# Else, apply y on gz.
gx = permute_row_elements(gz, y, eq(inverse, 0))
# If x has been broadcasted along some axes, we need to sum
# the gradient over these axes, but keep the dimension (as
# broadcastable)
broadcasted_dims = [dim for dim in range(gz.type.ndim)\
if x.type.broadcastable[dim] and not gz.type.broadcastable[dim]]
gx = Sum(axis = broadcasted_dims)(gx)
# Sum(...) removed the dimensions in broadcasted_dims,
# so we need to put them back.
newdims = []
i = 0
for dim in range(gz.type.ndim):
if dim in broadcasted_dims:
newdims.append('x')
else:
newdims.append(i)
i += 1
gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
return [gx, None, None]
_permute_row_elements = PermuteRowElements()
......
......@@ -1939,7 +1939,7 @@ class TestInversePermutation(unittest.TestCase):
assert numpy.all(inv_val[p_val] == numpy.arange(10))
def test_dim2(self):
"""Test the inversion of several permutation at a time"""
"""Test the inversion of several permutations at a time"""
# Each row of p is a different permutation to inverse
p = imatrix()
inv = inverse_permutation(p)
......@@ -2030,6 +2030,56 @@ class TestPermuteRowElements(unittest.TestCase):
return permute_row_elements(s_input, p_val)
utt.verify_grad(permute_fixed, [input_val])
def test_1_2(self):
"""Test PermuteRowElements(vector, matrix)
Different permutations will be applied to the same input vector"""
input = vector()
p = imatrix()
out = permute_row_elements(input, p)
permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(5,))
p_val = numpy.asarray([rng.permutation(5) for i in range(3)])
out_val = permute(input_val, p_val)
# Each row of p contains a permutation to apply to the input vector
out_bis = numpy.asarray([input_val[p_row] for p_row in p_val])
assert numpy.all(out_val == out_bis)
# Verify gradient
def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val"""
return permute_row_elements(s_input, p_val)
utt.verify_grad(permute_fixed, [input_val])
def test_3b_2(self):
"""Test permute_row_elements on a more complex broadcasting pattern:
input.type.broadcastable = (False, True, False),
p.type.broadcastable = (False, False)."""
input = TensorType('float64', (False, True, False))()
p = imatrix()
out = permute_row_elements(input, p)
permute = function([input, p], out)
rng = numpy.random.RandomState(utt.fetch_seed())
input_val = rng.uniform(size=(4,1,5))
p_val = numpy.asarray([rng.permutation(5) for i in range(3)])
out_val = permute(input_val, p_val)
# Each row of p contains a permutation to apply to each row
# of the input tensor
out_bis = numpy.asarray([[in_mat[0,p_row] for p_row in p_val] for in_mat in input_val])
assert numpy.all(out_val == out_bis)
# Verify gradient
def permute_fixed(s_input):
"""Auxiliary op defined to get rid of gradient wrt p_val"""
return permute_row_elements(s_input, p_val)
utt.verify_grad(permute_fixed, [input_val])
class test_tensordot(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论