提交 8a226a76 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify makeKeepdDims

上级 9ba6d99f
......@@ -297,32 +297,9 @@ def makeKeepDims(x, y, axis):
"""
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if axis is None:
axis = list(range(x.type.ndim))
elif isinstance(axis, int | np.integer):
axis = [axis]
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
newaxis = []
for a in axis:
if not isinstance(a, int):
raise ValueError("keepdims option can be used only with constant axis")
if a < 0:
a += x.type.ndim
newaxis.append(a)
i = 0
new_dims = []
for j, _ in enumerate(x.type.broadcastable):
if j in newaxis:
new_dims.append("x")
else:
new_dims.append(i)
i += 1
return DimShuffle(y.type.broadcastable, new_dims)(y)
return expand_dims(y, axis)
def check_and_normalize_axes(x, axis):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论