提交 1eafa0b0 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed gradient of argmax when axis is None

上级 92ec8cb2
......@@ -1964,8 +1964,11 @@ class MaxAndArgmax(Op):
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
out_dim = 0
for i in range(inp[0].ndim):
if i == axis.data:
if python_all(axis.data == range(x.ndim)):
# We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i == axis.data:
pattern.append('x')
else:
pattern.append(out_dim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论