Unverified 提交 58aad80a authored 作者: Pham Nguyen Hung's avatar Pham Nguyen Hung 提交者: GitHub

Avoid dimshuffle if expand_dims has empty axis (#724)

上级 86bc1d29
......@@ -4247,6 +4247,17 @@ def expand_dims(
Insert a new axis that will appear at the `axis` position in the expanded
array shape.
Parameters
----------
a :
The input array.
axis :
Position in the expanded axes where the new axis is placed.
If `axis` is empty, `a` will be returned immediately.
Returns
-------
`a` with a new axis at the `axis` position.
"""
a = as_tensor(a)
......@@ -4256,6 +4267,9 @@ def expand_dims(
out_ndim = len(axis) + a.ndim
axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
if not axis:
return a
dim_it = iter(range(a.ndim))
pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论