提交 11603832 authored 作者: abalkin's avatar abalkin

Reuse advanced_subtensor1 if indices is a vector.

上级 d6c185de
......@@ -6855,15 +6855,28 @@ class Take(Op):
shape = a_shape[:self.axis] + indices_shape + a_shape[self.axis+1:]
return [shape]
def take(a, indices, axis=None, mode='raise'):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
# Reuse advanced indexing in supported cases.
if axis is None and mode == 'raise':
if indices.ndim == 1:
return a.flatten()[indices]
# Reuse advanced_subtensor1 if indices is a vector
if indices.ndim == 1:
if mode == 'clip':
indices = clip(indices, 0, a.shape[axis]-1)
elif mode == 'wrap':
indices = indices % a.shape[axis]
if axis is None:
return advanced_subtensor1(a.flatten(), indices)
elif axis == 0:
return advanced_subtensor1(a, indices)
else:
if axis < 0:
axis += a.ndim
assert axis >= 0
shuffle = range(a.ndim)
shuffle[0] = axis
shuffle[axis] = 0
return advanced_subtensor1(
a.dimshuffle(shuffle), indices).dimshuffle(shuffle)
return Take(axis, mode)(a, indices)
#########################
......
......@@ -7174,6 +7174,15 @@ class TestTensorInstanceMethods(unittest.TestCase):
assert_array_equal(X.take(indices).eval({X: x}), x.take(indices))
indices = [1,0,1]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
indices = [-10,5,12]
assert_array_equal(X.take(indices, 1, mode='wrap').eval({X: x}),
x.take(indices, 1, mode='wrap'))
assert_array_equal(X.take(indices, -1, mode='wrap').eval({X: x}),
x.take(indices, -1, mode='wrap'))
assert_array_equal(X.take(indices, 1, mode='clip').eval({X: x}),
x.take(indices, 1, mode='clip'))
assert_array_equal(X.take(indices, -1, mode='clip').eval({X: x}),
x.take(indices, -1, mode='clip'))
indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论