提交 1fa8df43 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add atleast_Nd and expand_dims functions

上级 b889ec1e
...@@ -10,8 +10,9 @@ import logging ...@@ -10,8 +10,9 @@ import logging
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial
from numbers import Number from numbers import Number
from typing import Dict from typing import Dict, Tuple, Union
import numpy as np import numpy as np
...@@ -39,6 +40,7 @@ from aesara.tensor.shape import ( ...@@ -39,6 +40,7 @@ from aesara.tensor.shape import (
shape, shape,
shape_padaxis, shape_padaxis,
shape_padleft, shape_padleft,
shape_padright,
) )
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
...@@ -4290,7 +4292,64 @@ def empty(shape, dtype=None): ...@@ -4290,7 +4292,64 @@ def empty(shape, dtype=None):
return AllocEmpty(dtype)(*shape) return AllocEmpty(dtype)(*shape)
def atleast_Nd(
*arys: Union[np.ndarray, TensorVariable], n: int = 1, left: bool = True
) -> TensorVariable:
"""Convert inputs to arrays with at least `n` dimensions."""
res = []
for ary in arys:
ary = as_tensor(ary)
if ary.ndim >= n:
result = ary
else:
result = (
shape_padleft(ary, n - ary.ndim)
if left
else shape_padright(ary, n - ary.ndim)
)
res.append(result)
if len(res) == 1:
return res[0]
else:
return res
atleast_1d = partial(atleast_Nd, n=1)
atleast_2d = partial(atleast_Nd, n=2)
atleast_3d = partial(atleast_Nd, n=3)
def expand_dims(
a: Union[np.ndarray, TensorVariable], axis: Tuple[int, ...]
) -> TensorVariable:
"""Expand the shape of an array.
Insert a new axis that will appear at the `axis` position in the expanded
array shape.
"""
a = as_tensor(a)
if not isinstance(axis, (tuple, list)):
axis = (axis,)
out_ndim = len(axis) + a.ndim
axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
shape_it = iter(a.shape)
shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
return a.reshape(shape)
__all__ = [ __all__ = [
"expand_dims",
"atleast_Nd",
"atleast_1d",
"atleast_2d",
"atleast_3d",
"choose", "choose",
"swapaxes", "swapaxes",
"stacklists", "stacklists",
......
...@@ -43,11 +43,13 @@ from aesara.tensor.basic import ( ...@@ -43,11 +43,13 @@ from aesara.tensor.basic import (
alloc, alloc,
arange, arange,
as_tensor_variable, as_tensor_variable,
atleast_Nd,
cast, cast,
choose, choose,
constant, constant,
default, default,
diag, diag,
expand_dims,
extract_constant, extract_constant,
eye, eye,
fill, fill,
...@@ -4118,3 +4120,52 @@ def test_allocempty(): ...@@ -4118,3 +4120,52 @@ def test_allocempty():
res = aesara.function([], empty_at)() res = aesara.function([], empty_at)()
assert res.shape == (2, 3) assert res.shape == (2, 3)
assert res.dtype == "int64" assert res.dtype == "int64"
def test_atleast_Nd():
ary1 = dscalar()
res_ary1 = atleast_Nd(ary1, n=1)
assert res_ary1.ndim == 1
for n in range(1, 3):
ary1, ary2 = dscalar(), dvector()
res_ary1, res_ary2 = atleast_Nd(ary1, ary2, n=n)
assert res_ary1.ndim == n
if n == ary2.ndim:
assert ary2 is res_ary2
else:
assert res_ary2.ndim == n
ary1_val = np.array(1.0, dtype=np.float64)
ary2_val = np.array([1.0, 2.0], dtype=np.float64)
res_ary1_val, res_ary2_val = aesara.function(
[ary1, ary2], [res_ary1, res_ary2]
)(ary1_val, ary2_val)
np_fn = np.atleast_1d if n == 1 else np.atleast_2d
assert np.array_equal(res_ary1_val, np_fn(ary1_val))
assert np.array_equal(res_ary2_val, np_fn(ary2_val))
def test_expand_dims():
x_at = dscalar()
res_at = expand_dims(x_at, 0)
x_val = np.array(1.0, dtype=np.float64)
exp_res = np.expand_dims(x_val, 0)
res_val = aesara.function([x_at], res_at)(x_val)
assert np.array_equal(exp_res, res_val)
x_at = dscalar()
res_at = expand_dims(x_at, (0, 1))
x_val = np.array(1.0, dtype=np.float64)
exp_res = np.expand_dims(x_val, (0, 1))
res_val = aesara.function([x_at], res_at)(x_val)
assert np.array_equal(exp_res, res_val)
x_at = dmatrix()
res_at = expand_dims(x_at, (2, 1))
x_val = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
exp_res = np.expand_dims(x_val, (2, 1))
res_val = aesara.function([x_at], res_at)(x_val)
assert np.array_equal(exp_res, res_val)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论