提交 45a2f784 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make `take` return a Subtensor

上级 0dbd512b
...@@ -2717,41 +2717,25 @@ def take(a, indices, axis=None, mode="raise"): ...@@ -2717,41 +2717,25 @@ def take(a, indices, axis=None, mode="raise"):
""" """
a = aesara.tensor.as_tensor_variable(a) a = aesara.tensor.as_tensor_variable(a)
indices = aesara.tensor.as_tensor_variable(indices) indices = aesara.tensor.as_tensor_variable(indices)
# Reuse advanced_subtensor1 if indices is a vector
if indices.ndim == 1: if not isinstance(axis, (int, type(None))):
if mode == "clip": raise TypeError("`axis` must be an integer or None")
indices = clip(indices, 0, a.shape[axis] - 1)
elif mode == "wrap": if axis is None and indices.ndim == 1:
indices = indices % a.shape[axis] return advanced_subtensor1(a.flatten(), indices)
if axis is None: elif axis == 0 and indices.ndim == 1:
return advanced_subtensor1(a.flatten(), indices) return advanced_subtensor1(a, indices)
elif axis == 0: elif axis < 0:
return advanced_subtensor1(a, indices) axis += a.ndim
else:
if axis < 0: if mode == "clip":
axis += a.ndim indices = clip(indices, 0, a.shape[axis] - 1)
assert axis >= 0 elif mode == "wrap":
shuffle = list(range(a.ndim)) indices = indices % a.shape[axis]
shuffle[0] = axis
shuffle[axis] = 0 full_indices = (slice(None),) * axis + (indices,)
return advanced_subtensor1(a.dimshuffle(shuffle), indices).dimshuffle(
shuffle return a[full_indices]
)
if axis is None:
shape = indices.shape
ndim = indices.ndim
else:
# If axis is 0, don't generate a useless concatenation.
if axis == 0:
shape = aesara.tensor.concatenate([indices.shape, a.shape[axis + 1 :]])
else:
if axis < 0:
axis += a.ndim
shape = aesara.tensor.concatenate(
[a.shape[:axis], indices.shape, a.shape[axis + 1 :]]
)
ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
__all__ = [ __all__ = [
......
...@@ -1449,7 +1449,7 @@ def test_take_cases(a, index, axis, mode): ...@@ -1449,7 +1449,7 @@ def test_take_cases(a, index, axis, mode):
fn_mode = aesara.compile.mode.get_default_mode() fn_mode = aesara.compile.mode.get_default_mode()
fn_mode = fn_mode.including( fn_mode = fn_mode.including(
"local_useless_subtensor", "local_useless_subtensor",
# "local_replace_AdvancedSubtensor", "local_replace_AdvancedSubtensor",
) )
f = aesara.function([a, index], a.take(index, axis=axis, mode=mode), mode=fn_mode) f = aesara.function([a, index], a.take(index, axis=axis, mode=mode), mode=fn_mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论