提交 c107dc41 authored 作者: abalkin's avatar abalkin

Added support for advanced indexing that reduces to x.take().

上级 aa7a7fdc
...@@ -1662,21 +1662,29 @@ class _tensor_py_operators: ...@@ -1662,21 +1662,29 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing # AdvancedIndexingError, advanced indexing
advanced = False advanced = False
for arg in args: axis = None
for i, arg in enumerate(args):
try: try:
arg == numpy.newaxis or Subtensor.convert(arg) arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError: except AdvancedIndexingError:
advanced = True if advanced:
axis = None
break break
else:
advanced = True
axis = i
if advanced: if advanced:
if (len(args) == 1 if (axis is not None
and isinstance(args[0], ( and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis+1:])
and isinstance(args[axis], (
numpy.ndarray,
list, list,
TensorVariable, TensorVariable,
TensorConstant, TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))): theano.tensor.sharedvar.TensorSharedVariable))):
return advanced_subtensor1(self, *args) return self.take(arg, axis)
else: else:
return AdvancedSubtensor()(self, *args) return AdvancedSubtensor()(self, *args)
else: else:
......
...@@ -7185,6 +7185,8 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7185,6 +7185,8 @@ class TestTensorInstanceMethods(unittest.TestCase):
x.take(indices, -1, mode='clip')) x.take(indices, -1, mode='clip'))
indices = [[1,0,1], [0,1,1]] indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论