提交 0daa4491 authored 作者: abalkin's avatar abalkin

Added .take() member function, added test, fixed axis=None case.

上级 f50d8f59
......@@ -1705,6 +1705,9 @@ class _tensor_py_operators:
return Subtensor(args)(self, *Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable)))
def take(self, indices, axis=None, mode='raise'):
return take(self, indices, axis, mode)
# COPYING
def copy(self):
return tensor_copy(self)
......@@ -6853,7 +6856,7 @@ def take(a, indices, axis=None, mode='raise'):
# Reuse advanced indexing in supported cases.
if axis is None:
if indices.ndim == 1:
return a.flatten[indices]
return a.flatten()[indices]
else:
if indices.ndim == 0:
item = [slice(None)] * a.ndim
......
......@@ -7167,7 +7167,11 @@ class TestTensorInstanceMethods(unittest.TestCase):
assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}),
x.diagonal(offset, axis1, axis2))
def test_take(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.take([1,0,3]).eval({X: x}), x.take([1,0,3]))
if __name__ == '__main__':
t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论