提交 c5a25e58 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Revert "remove use of take in advanced subtensor"

This reverts commit baaa8111 and 701f8b54
上级 0de42ea0
......@@ -1752,16 +1752,29 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing
advanced = False
for arg in args:
axis = None
for i, arg in enumerate(args):
try:
arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError:
advanced = True
break
if advanced:
axis = None
break
else:
advanced = True
axis = i
if advanced:
if (len(args) == 1 and as_tensor_variable(args[0]).ndim <= 1):
return advanced_subtensor1(self, *args)
if (axis is not None
and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], (
numpy.ndarray,
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis)
else:
return AdvancedSubtensor()(self, *args)
else:
......
......@@ -7425,6 +7425,8 @@ class TestTensorInstanceMethods(unittest.TestCase):
self.assertRaises(TypeError, X.take, [0.0])
indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论