提交 5d97ffa6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement moveaxis

上级 0d698099
...@@ -10,11 +10,14 @@ import warnings ...@@ -10,11 +10,14 @@ import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Optional, Tuple, Union from typing import Optional
from typing import Sequence as TypeSequence
from typing import Tuple, Union
from typing import cast as type_cast from typing import cast as type_cast
import numpy as np import numpy as np
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
from numpy.core.numeric import normalize_axis_tuple
import aesara import aesara
import aesara.scalar.sharedvar import aesara.scalar.sharedvar
...@@ -3635,6 +3638,51 @@ def swapaxes(y, axis1, axis2): ...@@ -3635,6 +3638,51 @@ def swapaxes(y, axis1, axis2):
return y.dimshuffle(li) return y.dimshuffle(li)
def moveaxis(
a: Union[np.ndarray, TensorVariable],
source: Union[int, TypeSequence[int]],
destination: Union[int, TypeSequence[int]],
) -> TensorVariable:
"""Move axes of a TensorVariable to new positions.
Other axes remain in their original order.
Parameters
----------
a
The TensorVariable whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes. These must also be
unique.
Returns
-------
result
TensorVariable with moved axes.
"""
a = as_tensor_variable(a)
source = normalize_axis_tuple(source, a.ndim, "source")
destination = normalize_axis_tuple(destination, a.ndim, "destination")
if len(source) != len(destination):
raise ValueError(
"`source` and `destination` arguments must have the same number of elements"
)
order = [n for n in range(a.ndim) if n not in source]
for dest, src in sorted(zip(destination, source)):
order.insert(dest, src)
result = a.dimshuffle(order)
return result
def choose(a, choices, mode="raise"): def choose(a, choices, mode="raise"):
""" """
Construct an array from an index array and a set of arrays to choose from. Construct an array from an index array and a set of arrays to choose from.
...@@ -4014,6 +4062,7 @@ __all__ = [ ...@@ -4014,6 +4062,7 @@ __all__ = [
"atleast_3d", "atleast_3d",
"choose", "choose",
"swapaxes", "swapaxes",
"moveaxis",
"stacklists", "stacklists",
"diag", "diag",
"diagonal", "diagonal",
......
...@@ -60,6 +60,7 @@ from aesara.tensor.basic import ( ...@@ -60,6 +60,7 @@ from aesara.tensor.basic import (
join, join,
make_vector, make_vector,
mgrid, mgrid,
moveaxis,
nonzero, nonzero,
nonzero_values, nonzero_values,
ogrid, ogrid,
...@@ -3984,6 +3985,23 @@ class TestSwapaxes: ...@@ -3984,6 +3985,23 @@ class TestSwapaxes:
assert np.allclose(n_s, t_s) assert np.allclose(n_s, t_s)
def test_moveaxis():
x = at.zeros((3, 4, 5))
tuple(moveaxis(x, 0, -1).shape.eval()) == (4, 5, 3)
tuple(moveaxis(x, -1, 0).shape.eval()) == (5, 3, 4)
tuple(moveaxis(x, [0, 1], [-1, -2]).shape.eval()) == (5, 4, 3)
tuple(moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape.eval()) == (5, 4, 3)
def test_moveaxis_error():
x = at.zeros((3, 4, 5))
with pytest.raises(
ValueError,
match="`source` and `destination` arguments must have the same number of elements",
):
moveaxis(x, [0, 1], 0)
class TestChoose(utt.InferShapeTester): class TestChoose(utt.InferShapeTester):
op = staticmethod(choose) op = staticmethod(choose)
op_class = Choose op_class = Choose
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论