提交 6d1813d9 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

sketch of grad for max

上级 85907fbb
......@@ -519,6 +519,10 @@ class MaxAndArgmax(Op):
def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.max(x, axis)
max_idx[0] = numpy.argmax(x, axis)
# def grad(self, (x, axis), (g_max, g_max_idx)):
# # This only works if axis is 0, else the max is broadcasted wrong in the call to eq
# g_x = eq(max(x, axis), x) * g_max
# return g_x, None
max_and_argmax = MaxAndArgmax()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论