Unverified 提交 52bbf59d authored 作者: Abhinav's avatar Abhinav 提交者: GitHub

Modify atleast_Nd to accept only one positional argument (#1291)

上级 f10a6036
...@@ -4355,28 +4355,22 @@ def empty_like( ...@@ -4355,28 +4355,22 @@ def empty_like(
def atleast_Nd( def atleast_Nd(
*arys: np.ndarray | TensorVariable, n: int = 1, left: bool = True arry: np.ndarray | TensorVariable, *, n: int = 1, left: bool = True
) -> TensorVariable: ) -> TensorVariable:
"""Convert inputs to arrays with at least `n` dimensions.""" """Convert input to an array with at least `n` dimensions."""
res = []
for ary in arys:
ary = as_tensor(ary)
if ary.ndim >= n: arry = as_tensor(arry)
result = ary
if arry.ndim >= n:
result = arry
else: else:
result = ( result = (
shape_padleft(ary, n - ary.ndim) shape_padleft(arry, n - arry.ndim)
if left if left
else shape_padright(ary, n - ary.ndim) else shape_padright(arry, n - arry.ndim)
) )
res.append(result) return result
if len(res) == 1:
return res[0]
else:
return res
atleast_1d = partial(atleast_Nd, n=1) atleast_1d = partial(atleast_Nd, n=1)
......
...@@ -4364,7 +4364,8 @@ def test_atleast_Nd(): ...@@ -4364,7 +4364,8 @@ def test_atleast_Nd():
for n in range(1, 3): for n in range(1, 3):
ary1, ary2 = dscalar(), dvector() ary1, ary2 = dscalar(), dvector()
res_ary1, res_ary2 = atleast_Nd(ary1, ary2, n=n) res_ary1 = atleast_Nd(ary1, n=n)
res_ary2 = atleast_Nd(ary2, n=n)
assert res_ary1.ndim == n assert res_ary1.ndim == n
if n == ary2.ndim: if n == ary2.ndim:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论