提交 bc89ee83 authored 作者: lamblin's avatar lamblin

Merge pull request #1127 from abalkin/take-op

Take op [WIP]
...@@ -1662,21 +1662,29 @@ class _tensor_py_operators: ...@@ -1662,21 +1662,29 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing # AdvancedIndexingError, advanced indexing
advanced = False advanced = False
for arg in args: axis = None
for i, arg in enumerate(args):
try: try:
arg == numpy.newaxis or Subtensor.convert(arg) arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError: except AdvancedIndexingError:
advanced = True if advanced:
break axis = None
break
else:
advanced = True
axis = i
if advanced: if advanced:
if (len(args) == 1 if (axis is not None
and isinstance(args[0], ( and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis+1:])
and isinstance(args[axis], (
numpy.ndarray,
list, list,
TensorVariable, TensorVariable,
TensorConstant, TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))): theano.tensor.sharedvar.TensorSharedVariable))):
return advanced_subtensor1(self, *args) return self.take(arg, axis)
else: else:
return AdvancedSubtensor()(self, *args) return AdvancedSubtensor()(self, *args)
else: else:
...@@ -1705,6 +1713,9 @@ class _tensor_py_operators: ...@@ -1705,6 +1713,9 @@ class _tensor_py_operators:
return Subtensor(args)(self, *Subtensor.collapse(args, return Subtensor(args)(self, *Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable))) lambda entry: isinstance(entry, Variable)))
def take(self, indices, axis=None, mode='raise'):
return take(self, indices, axis, mode)
# COPYING # COPYING
def copy(self): def copy(self):
return tensor_copy(self) return tensor_copy(self)
...@@ -6811,6 +6822,36 @@ class AdvancedIncSubtensor(Op): ...@@ -6811,6 +6822,36 @@ class AdvancedIncSubtensor(Op):
*inputs[2:]).outputs *inputs[2:]).outputs
advanced_inc_subtensor = AdvancedIncSubtensor() advanced_inc_subtensor = AdvancedIncSubtensor()
def take(a, indices, axis=None, mode='raise'):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
# Reuse advanced_subtensor1 if indices is a vector
if indices.ndim == 1:
if mode == 'clip':
indices = clip(indices, 0, a.shape[axis]-1)
elif mode == 'wrap':
indices = indices % a.shape[axis]
if axis is None:
return advanced_subtensor1(a.flatten(), indices)
elif axis == 0:
return advanced_subtensor1(a, indices)
else:
if axis < 0:
axis += a.ndim
assert axis >= 0
shuffle = range(a.ndim)
shuffle[0] = axis
shuffle[axis] = 0
return advanced_subtensor1(
a.dimshuffle(shuffle), indices).dimshuffle(shuffle)
if axis is None:
shape = indices.shape
ndim = indices.ndim
else:
shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]])
ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
......
...@@ -7167,6 +7167,26 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7167,6 +7167,26 @@ class TestTensorInstanceMethods(unittest.TestCase):
assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}), assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}),
x.diagonal(offset, axis1, axis2)) x.diagonal(offset, axis1, axis2))
def test_take(self):
X, _ = self.vars
x, _ = self.vals
indices = [1,0,3]
assert_array_equal(X.take(indices).eval({X: x}), x.take(indices))
indices = [1,0,1]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
indices = [-10,5,12]
assert_array_equal(X.take(indices, 1, mode='wrap').eval({X: x}),
x.take(indices, 1, mode='wrap'))
assert_array_equal(X.take(indices, -1, mode='wrap').eval({X: x}),
x.take(indices, -1, mode='wrap'))
assert_array_equal(X.take(indices, 1, mode='clip').eval({X: x}),
x.take(indices, 1, mode='clip'))
assert_array_equal(X.take(indices, -1, mode='clip').eval({X: x}),
x.take(indices, -1, mode='clip'))
indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论