提交 7ed2b04f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Max and ArgMax R operator

上级 9e9e9ad2
...@@ -1877,6 +1877,12 @@ class MaxAndArgmax(Op): ...@@ -1877,6 +1877,12 @@ class MaxAndArgmax(Op):
rval = tuple([ishape[i] for (i,b) in enumerate(node.inputs[0].type.broadcastable) if i !=axis.data]) rval = tuple([ishape[i] for (i,b) in enumerate(node.inputs[0].type.broadcastable) if i !=axis.data])
return [rval,rval] return [rval,rval]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None, None]
return [self.make_node(eval_points[0], inputs[1]).outputs[0], None]
def grad(self, inp, grads): def grad(self, inp, grads):
# @warning: This only works if axis is 0, else the max is # @warning: This only works if axis is 0, else the max is
# broadcasted wrong in the call to eq. # broadcasted wrong in the call to eq.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论