modified grad of MaxAndArgMax to support dim=0 or dim=ndim-1

TODO: implementation for any axis is pretty straightforward (pad_shapeleft + DimShuffle)
上级 46dd9a6b
......@@ -809,9 +809,16 @@ class MaxAndArgmax(Op):
# gMax * dMax/dx + gArgMax * dArgMax/dx, gMax * dMax/daxis + gArgMax * dArgMax/daxis
# g_max has one less dimension than x, so you need to complete g_max to x's shape
# when axis=0 the broadcasting mechanism does it automatically
assert axis.data == 0
g_x = eq(max(x, axis), x) * g_max
assert axis.data == 0 or axis.data == x.ndim-1
g_max_pad = shape_padleft(g_max) if axis.data==0 else \
shape_padright(g_max)
xmax = max(x, axis)
xmax_pad = shape_padleft(xmax) if axis.data==0 else \
shape_padright(xmax)
g_x = eq(xmax_pad, x) * g_max_pad
return g_x, None
@_redefine_asRoutine(MaxAndArgmax())
def max_and_argmax(a):
pass
......@@ -1624,7 +1631,7 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join),
@constructor
def shape_padleft(tensor, n_ones):
def shape_padleft(tensor, n_ones=1):
"""Reshape `tensor` by left-padding the shape with `n_ones` 1s
See also: `shape_padright` and `Dimshuffle`
......@@ -1639,7 +1646,7 @@ def rightpad_shape(tensor, n_ones):
return DimShuffle(tensor.broadcastable, pattern)(tensor)
@constructor
def shape_padright(tensor, n_ones):
def shape_padright(tensor, n_ones=1):
"""Reshape `tensor` by right-padding the shape with `n_ones` 1s
See also: `shape_padleft` and `Dimshuffle`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论