提交 60bc3688 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use Dimshuffle for expand_dims

上级 652d0b69
...@@ -4013,10 +4013,10 @@ def expand_dims( ...@@ -4013,10 +4013,10 @@ def expand_dims(
out_ndim = len(axis) + a.ndim out_ndim = len(axis) + a.ndim
axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim) axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
shape_it = iter(a.shape) dim_it = iter(range(a.ndim))
shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)]
return a.reshape(shape) return a.dimshuffle(pattern)
def _make_along_axis_idx(arr_shape, indices, axis): def _make_along_axis_idx(arr_shape, indices, axis):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论