提交 2feb66f4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More documentation. Add the ability to apply the inverse permutation instead.

上级 8e9f3ff6
......@@ -2707,7 +2707,7 @@ def tile(x, reps, ndim=None):
tile.op = {}
if ndim is None:
ndim = len(reps)
#backport
#ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going to work.
if ndim not in tile.op:
......@@ -2738,42 +2738,80 @@ inverse_permutation = InversePermutation()
class PermuteRowElements(Op):
"""Permute the elements of each row (inner-most dim) of a tensor.
The permutation argument (y) will be broadcasted to fit x, then each
row (vector) of x will be reordered according to the corresponding row
of y.
A permutation will be applied to every row (vector) of the input tensor x.
Depending on the dimensionality of x and the permutation tensor y,
different cases are possible.
If y.ndim = 1, y is a single permutation, that will be applied to every
vector of x. For instance, if x is a matrix, the same permutation will be
applied to each row of x.
If x.ndim = y.ndim, each row of x corresponds to a row of y, containing
a permutation that will be applied to that row. For instance, if x and y
are two matrices, a different permutation will be applied to each row of x.
If x.ndim > y.ndim, y will be broadcasted to fit x, then each row (vector)
of x will be reordered according to the corresponding row of y. (This is
a generalization of the first case).
WARNING: x will not be broadcasted to fit y (not implemented yet).
If the "inverse" argument is True, the Op will perform the inverse
permutation instead.
"""
def make_node(self, x, y):
def make_node(self, x, y, inverse):
x = as_tensor_variable(x)
y = as_tensor_variable(y)
inverse = as_tensor_variable(inverse)
print 'in make_node: inverse =', inverse
# y should contain integers
assert y.type.dtype.startswith('int') or y.type.dtype.startswith('uint')
# Inverse should be an integer scalar
assert inverse.type.ndim == 0 and\
(inverse.type.dtype.startswith('int') or\
inverse.type.dtype.startswith('uint'))
# extend y dimension to match x
assert x.type.ndim >= y.type.ndim
y = shape_padleft(y, n_ones=(x.type.ndim - y.type.ndim))
inputlist = [x, y]
inputlist = [x, y, inverse]
outputlist = [x.type()]
return Apply(self, inputlist, outputlist)
def _rec_perform(self, node, x, y, out, curdim):
def _rec_perform(self, node, x, y, inverse, out, curdim):
"""Perform the permutation by doing a recursion over the input dimensions.
For every dimension, starting with the leftmost, the right set of
indices is determined (depending if broadcasting or not), then
the function is recursively called on the appropriate subtensors.
The terminal case is reached when the current tensors are vector,
then the permutation contained in y is applied to x.
:param x: The input tensor, on which the permutation is applied
:param y: Tensor containing the permutations to apply
:param out: Tensor storing the output result
:param curdim: Counter of the current depth of recursion
:param inverse: Wether to apply permutations or their inverse
"""
if len(x.shape) == 1:
# Numpy advanced indexing works in this case
out[:] = x[y]
if inverse:
out[y] = x[:]
else:
out[:] = x[y]
else:
xs0 = x.shape[0]
ys0 = y.shape[0]
if xs0 == ys0:
for i in range(xs0):
self._rec_perform(node, x[i], y[i], out[i], curdim+1)
self._rec_perform(node, x[i], y[i], inverse, out[i], curdim+1)
elif node.inputs[1].type.broadcastable[curdim]:
# Broadcast y
for i in range(xs0):
self._rec_perform(node, x[i], y[0], out[i], curdim+1)
self._rec_perform(node, x[i], y[0], inverse, out[i], curdim+1)
else:
raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0))
def perform(self, node, (x, y), (outs,)):
def perform(self, node, (x, y, inverse), (outs,)):
x_s = x.shape
y_s = y.shape
assert len(x_s) == len(y_s)
......@@ -2781,13 +2819,17 @@ class PermuteRowElements(Op):
if outs[0] is None or outs[0].shape != x_s:
outs[0] = numpy.empty_like(x)
self._rec_perform(node, x, y, outs[0], curdim=0)
self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
def grad(self, (x, y), (gz,)):
gx = permute_row_elements(gz, inverse_permutation(y))
return [gx, None]
def grad(self, (x, y, inverse), (gz,)):
"""If 'inverse' is False (0), apply the inverse of y on gz.
Else, apply y on gz."""
gx = permute_row_elements(gz, y, eq(inverse, 0))
return [gx, None, None]
permute_row_elements = PermuteRowElements()
_permute_row_elements = PermuteRowElements()
def permute_row_elements(x, y, inverse=0):
return _permute_row_elements(x, y, inverse)
#########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论