提交 3d873c2a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

First draft of advanced indexing

上级 92194ff9
......@@ -825,10 +825,25 @@ class _tensor_py_operators:
def __getitem__(self, args):
if not isinstance(args, tuple):
args = args,
return Subtensor(args)(self, *Subtensor.collapse(args, lambda entry: isinstance(entry, Variable)))
# Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used, else, advanced indexing
advanced = False
for arg in args:
try:
Subtensor.convert(arg)
except TypeError:
advanced = True
break
if advanced:
return AdvancedSubtensor(args)(self, *args)
else:
return Subtensor(args)(self, *Subtensor.collapse(args, lambda entry: isinstance(entry, Variable)))
def __getslice__(self, *args):
args = slice(*args),
return Subtensor(args)(self, *Subtensor.collapse(args, lambda entry: isinstance(entry, Variable)))
return self.__getitem__(args)
#COPYING
def copy(self):
......@@ -2908,6 +2923,82 @@ def inverse_permutation(perm):
"""
return permute_row_elements(arange(perm.shape[-1]), perm, inverse=True)
#########################
# Advanced indexing
#########################
#
# Should reproduce numpy's behaviour:
# http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing.
"""
# Should be used by __getitem__ and __getslice__, as follow:
# AdvancedSubtensor(args)(self, *args),
# if args contains and advanced indexing pattern
def __init__(self, args): #idx_list?
# For the moment, __init__ will be passed the whole list of arguments
#TODO: see what's the best solution
self.args = args #?
#FIXME
if len(args) != 2:
print >>sys.stderr, 'WARNING: Advanced indexing with %i arguments not supported yet' % len(args)
print >>sys.stderr, ' arguments are:', args
def make_node(self, x, *inputs):
x = as_tensor_variable(x)
#FIXME
if x.ndim == 2 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1])
if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self,
(x,) + inputs,
[tensor(dtype = x.type.dtype,
broadcastable = [False])])
raise NotImplementedError('Advanced indexing of x (of dimension %i) with these argument dimensions (%s) not supported yet'\
% (x.ndim, ','.join(str(input.ndim) for input in inputs)))
def perform(self, node, inputs, (out,)):
pass
def grad(self, inputs, (gz,)):
x = inputs[0]
rest = inputs[1:]
return [AdvancedIncSubtensor(self.args)(zeros_like(x), gz, *rest)] + [None]*len(rest)
class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing.
"""
def __init__(self, args): #idx_list? inplace=False?
self.args = args
def make_node(self, x, y, *inputs):
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1])
if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self,
(x, y) + inputs,
[tensor(dtype = x.type.dtype,
broadcastable = x.type.broadcastable)])
raise NotImplementedError('Advanced indexing increment of x (of dimension %i) by y (of dimension %i) with these argument dimensions (%s) not supported yet'\
% (x.ndim, y.ndim, ','.join(str(input.ndim) for input in inputs)))
def perform(self, node, inputs, (out,)):
pass
#def grad?
#########################
# Linalg : Dot
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论