提交 584754da authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add two ops:

InversePermutation ReorderRowElements
上级 2d6f8566
...@@ -2692,6 +2692,78 @@ def tile(x, reps, ndim=None): ...@@ -2692,6 +2692,78 @@ def tile(x, reps, ndim=None):
tile.op[ndim] = Tile(ndim) tile.op[ndim] = Tile(ndim)
return tile.op[ndim](x, reps) return tile.op[ndim](x, reps)
class InversePermutation(Op):
"""Computes the inverse of permutations.
Each row of input should contain a permutation of the first integers.
"""
def make_node(self, x):
return Apply(self, [x], [x.type()])
def perform(self, node, (x,), (outs,)):
if outs[0] is None or outs[0].shape != x.shape:
outs[0] = numpy.empty_like(x)
for i in numpy.ndindex(x.shape[:-1]):
outs[0][i][x[i]] = numpy.arange(x.shape[0])
def grad(self, (x,), (gz,)):
return [None]
inverse_permutation = InversePermutation()
class ReorderRowElements(Op):
"""Reorder each row (inner-most dim) of a tensor wrt a permutation.
The permutation argument (y) will be broadcasted to fit x, then each
row (vector) of x will be reordered according to the corresponding row
of y.
WARNING: x will not be broadcasted to fit y (not implemented yet).
"""
def make_node(self, x, y):
assert x.type.ndim >= y.type.ndim
x = as_tensor_variable(x)
# extend y dimension to match x
assert y.type.dtype.startswith('int') or y.type.dtype.startswith('uint')
y = as_tensor_variable(y, ndim = x.type.ndim)
inputlist = [x, y]
outputlist = [x.type()]
return Apply(self, inputlist, outputlist)
def _rec_perform(self, node, x, y, out, curdim):
if len(x.shape) == 1:
# Numpy advanced indexing works in this case
out[:] = x[y]
else:
xs0 = x.shape[0]
ys0 = y.shape[0]
if xs0 == ys0:
for i in range(xs0):
self._rec_perform(node, x[i], y[i], out[i], curdim+1)
elif node.inputs[1].type.broadcastable[curdim]:
# Broadcast y
for i in range(xs0):
self._rec_perform(node, x[i], y[0], out[i], curdim+1)
else:
raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0))
def perform(self, node, (x, y), (outs,)):
x_s = x.shape
y_s = y.shape
assert len(x_s) == len(y_s)
if outs[0] is None or outs[0].shape != x_s:
outs[0] = numpy.empty_like(x)
self._rec_perform(node, x, y, outs[0], curdim=0)
def grad(self, (x, y), (gz,)):
gx = reorder_row_elements(gz, inverse_permutation(y))
return [gx, None]
reorder_row_elements = ReorderRowElements()
######################### #########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论