提交 baaa8111 authored 作者: John Salvatier's avatar John Salvatier

remove use of take in advanced subtensor

上级 6e4511f4
...@@ -1748,29 +1748,21 @@ class _tensor_py_operators: ...@@ -1748,29 +1748,21 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing # AdvancedIndexingError, advanced indexing
advanced = False advanced = False
axis = None for arg in args:
for i, arg in enumerate(args):
try: try:
arg == numpy.newaxis or Subtensor.convert(arg) arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError: except AdvancedIndexingError:
if advanced: advanced = True
axis = None break
break
else:
advanced = True
axis = i
if advanced: if advanced:
if (axis is not None if (len(args) == 1
and numpy.all(a == slice(None) for a in args[:axis]) and isinstance(args[0], (
and numpy.all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], (
numpy.ndarray,
list, list,
TensorVariable, TensorVariable,
TensorConstant, TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))): theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis) return advanced_subtensor1(self, *args)
else: else:
return AdvancedSubtensor()(self, *args) return AdvancedSubtensor()(self, *args)
else: else:
......
...@@ -7410,8 +7410,6 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7410,8 +7410,6 @@ class TestTensorInstanceMethods(unittest.TestCase):
self.assertRaises(TypeError, X.take, [0.0]) self.assertRaises(TypeError, X.take, [0.0])
indices = [[1,0,1], [0,1,1]] indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论