提交 baaa8111 authored 作者: John Salvatier's avatar John Salvatier

remove use of take in advanced subtensor

上级 6e4511f4
......@@ -1748,29 +1748,21 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing
advanced = False
axis = None
for i, arg in enumerate(args):
for arg in args:
try:
arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError:
if advanced:
axis = None
break
else:
advanced = True
axis = i
advanced = True
break
if advanced:
if (axis is not None
and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], (
numpy.ndarray,
if (len(args) == 1
and isinstance(args[0], (
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis)
return advanced_subtensor1(self, *args)
else:
return AdvancedSubtensor()(self, *args)
else:
......
......@@ -7410,8 +7410,6 @@ class TestTensorInstanceMethods(unittest.TestCase):
self.assertRaises(TypeError, X.take, [0.0])
indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论