提交 22f463c1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add take_along_axis function

上级 69dc7d15
...@@ -15,6 +15,7 @@ from numbers import Number ...@@ -15,6 +15,7 @@ from numbers import Number
from typing import Dict, Tuple, Union from typing import Dict, Tuple, Union
import numpy as np import numpy as np
from numpy.core.multiarray import normalize_axis_index
import aesara import aesara
import aesara.scalar.sharedvar import aesara.scalar.sharedvar
...@@ -4347,7 +4348,55 @@ def expand_dims( ...@@ -4347,7 +4348,55 @@ def expand_dims(
return a.reshape(shape) return a.reshape(shape)
def _make_along_axis_idx(arr_shape, indices, axis):
"""Take from `numpy.lib.shape_base`."""
# compute dimensions to iterate over
if str(indices.dtype) not in int_dtypes:
raise IndexError("`indices` must be an integer array")
shape_ones = (1,) * indices.ndim
dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim))
# build a fancy index, consisting of orthogonal aranges, with the
# requested index inserted at the right location
fancy_index = []
for dim, n in zip(dest_dims, arr_shape):
if dim is None:
fancy_index.append(indices)
else:
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
fancy_index.append(arange(n).reshape(ind_shape))
return tuple(fancy_index)
def take_along_axis(arr, indices, axis=0):
"""Take values from the input array by matching 1d index and data slices.
This iterates over matching 1d slices oriented along the specified axis in
the index and data arrays, and uses the former to look up values in the
latter. These slices can be different lengths.
Functions returning an index along an axis, like `argsort` and
`argpartition`, produce suitable indices for this function.
"""
arr = as_tensor_variable(arr)
indices = as_tensor_variable(indices)
# normalize inputs
if axis is None:
arr = arr.flatten()
axis = 0
else:
axis = normalize_axis_index(axis, arr.ndim)
if arr.ndim != indices.ndim:
raise ValueError("`indices` and `arr` must have the same number of dimensions")
# use the fancy index
return arr[_make_along_axis_idx(arr.shape, indices, axis)]
__all__ = [ __all__ = [
"take_along_axis",
"expand_dims", "expand_dims",
"atleast_Nd", "atleast_Nd",
"atleast_1d", "atleast_1d",
......
...@@ -4175,3 +4175,44 @@ def test_expand_dims(): ...@@ -4175,3 +4175,44 @@ def test_expand_dims():
exp_res = np.expand_dims(x_val, (2, 1)) exp_res = np.expand_dims(x_val, (2, 1))
res_val = aesara.function([x_at], res_at)(x_val) res_val = aesara.function([x_at], res_at)(x_val)
assert np.array_equal(exp_res, res_val) assert np.array_equal(exp_res, res_val)
class TestTakeAlongAxis:
@pytest.mark.parametrize(
["shape", "axis", "samples"],
(
((1,), None, 1),
((1,), -1, 10),
((3, 2, 1), -1, 1),
((3, 2, 1), 0, 10),
((3, 2, 1), -1, 10),
),
ids=str,
)
def test_take_along_axis(self, shape, axis, samples):
rng = np.random.default_rng()
arr = rng.normal(size=shape).astype(config.floatX)
indices_size = list(shape)
indices_size[axis or 0] = samples
indices = rng.integers(low=0, high=shape[axis or 0], size=indices_size)
arr_in = aet.tensor(config.floatX, [s == 1 for s in arr.shape])
indices_in = aet.tensor(np.int64, [s == 1 for s in indices.shape])
out = aet.take_along_axis(arr_in, indices_in, axis)
func = aesara.function([arr_in, indices_in], out)
assert np.allclose(
np.take_along_axis(arr, indices, axis=axis), func(arr, indices)
)
def test_ndim_dtype_failures(self):
arr = aet.tensor(config.floatX, [False] * 2)
indices = aet.tensor(np.int64, [False] * 3)
with pytest.raises(ValueError):
aet.take_along_axis(arr, indices)
indices = aet.tensor(np.float64, [False] * 2)
with pytest.raises(IndexError):
aet.take_along_axis(arr, indices)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论