提交 938b47bb authored 作者: Joseph Turian's avatar Joseph Turian

At olivier's advice, uncommented MaxAndArgmax.grad

上级 86880fa9
...@@ -519,10 +519,16 @@ class MaxAndArgmax(Op): ...@@ -519,10 +519,16 @@ class MaxAndArgmax(Op):
def perform(self, node, (x, axis), (max, max_idx)): def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.max(x, axis) max[0] = numpy.max(x, axis)
max_idx[0] = numpy.argmax(x, axis) max_idx[0] = numpy.argmax(x, axis)
# def grad(self, (x, axis), (g_max, g_max_idx)): def grad(self, (x, axis), (g_max, g_max_idx)):
# # This only works if axis is 0, else the max is broadcasted wrong in the call to eq # (x, y), (gz, gw)
# g_x = eq(max(x, axis), x) * g_max # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
# return g_x, None # gMax * dMax/dx + gArgMax * dArgMax/dx, gMax * dMax/daxis + gArgMax * dArgMax/daxis
# g_max has one less dimension than x, so you need to complete g_max to x's shape
# when axis=0 the broadcasting mechanism does it automatically
# This only works if axis is 0, else the max is broadcasted wrong in the call to eq
assert axis.data == 0
g_x = eq(max(x, axis), x) * g_max
return g_x, None
max_and_argmax = MaxAndArgmax() max_and_argmax = MaxAndArgmax()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论