提交 1eafa0b0 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed gradient of argmax when axis is None

上级 92ec8cb2
...@@ -1964,8 +1964,11 @@ class MaxAndArgmax(Op): ...@@ -1964,8 +1964,11 @@ class MaxAndArgmax(Op):
# Raise the g_max and xmax to the same number of dim as the input. # Raise the g_max and xmax to the same number of dim as the input.
pattern = [] pattern = []
out_dim = 0 out_dim = 0
for i in range(inp[0].ndim): if python_all(axis.data == range(x.ndim)):
if i == axis.data: # We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i == axis.data:
pattern.append('x') pattern.append('x')
else: else:
pattern.append(out_dim) pattern.append(out_dim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论