提交 106b8308 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed Rop of argmax that caused test to fail

上级 f5ccee2e
...@@ -1884,7 +1884,22 @@ class MaxAndArgmax(Op): ...@@ -1884,7 +1884,22 @@ class MaxAndArgmax(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None, None] return [None, None]
return [self.make_node(eval_points[0], inputs[1]).outputs[0], None] if not isinstance(inputs[1], theano.Constant):
raise ValueError( ('R_op supported for arg_max only for '
'constant axis!'))
if inputs[1].data > 1:
raise ValueError( ('R_op supported for arg_max only when '
' axis is 0 or 1'))
if inputs[0].ndim != 2:
raise ValueError( ('R_op supported for arg_max only when '
' input is a matrix'))
max_vals, max_pos = self.make_node(*inputs).outputs
if inputs[1].data == 0:
return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None]
else:
return [eval_points[0][arange(eval_points[0].shape[0]),
max_pos], None]
def grad(self, inp, grads): def grad(self, inp, grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论