提交 0daa4491 authored 作者: abalkin's avatar abalkin

Added .take() member function, added test, fixed axis=None case.

上级 f50d8f59
...@@ -1705,6 +1705,9 @@ class _tensor_py_operators: ...@@ -1705,6 +1705,9 @@ class _tensor_py_operators:
return Subtensor(args)(self, *Subtensor.collapse(args, return Subtensor(args)(self, *Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable))) lambda entry: isinstance(entry, Variable)))
def take(self, indices, axis=None, mode='raise'):
return take(self, indices, axis, mode)
# COPYING # COPYING
def copy(self): def copy(self):
return tensor_copy(self) return tensor_copy(self)
...@@ -6853,7 +6856,7 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -6853,7 +6856,7 @@ def take(a, indices, axis=None, mode='raise'):
# Reuse advanced indexing in supported cases. # Reuse advanced indexing in supported cases.
if axis is None: if axis is None:
if indices.ndim == 1: if indices.ndim == 1:
return a.flatten[indices] return a.flatten()[indices]
else: else:
if indices.ndim == 0: if indices.ndim == 0:
item = [slice(None)] * a.ndim item = [slice(None)] * a.ndim
......
...@@ -7167,6 +7167,10 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7167,6 +7167,10 @@ class TestTensorInstanceMethods(unittest.TestCase):
assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}), assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}),
x.diagonal(offset, axis1, axis2)) x.diagonal(offset, axis1, axis2))
def test_take(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.take([1,0,3]).eval({X: x}), x.take([1,0,3]))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论