提交 f50d8f59 authored 作者: abalkin's avatar abalkin

Reuse advanced indexing in supported cases.

上级 8ea11b10
......@@ -6848,6 +6848,17 @@ class Take(Op):
def take(a, indices, axis=None, mode='raise'):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
# Reuse advanced indexing in supported cases.
if axis is None:
if indices.ndim == 1:
return a.flatten[indices]
else:
if indices.ndim == 0:
item = [slice(None)] * a.ndim
item[axis] = indices
return a[tuple(item)]
return Take(axis, mode)(a, indices)
#########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论