提交 b675d135 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move argmax helper close to class definition

上级 8a226a76
......@@ -277,6 +277,28 @@ class Argmax(COp):
return [x.zeros_like()]
def argmax(x, axis=None, keepdims=False):
"""
Returns indices of maximum elements obtained by iterating over given axis.
When axis is None (the default value), the argmax is performed
over the flattened tensor.
Parameters
----------
keepdims : bool
If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
"""
argout = max_and_argmax(x, axis)[1]
if keepdims:
argout = makeKeepDims(x, argout, axis)
return argout
@_vectorize_node.register(Argmax)
def vectorize_argmax_node(op, node, batch_x):
core_ndim = node.inputs[0].type.ndim
......@@ -549,28 +571,6 @@ def max(x, axis=None, keepdims=False):
return out
def argmax(x, axis=None, keepdims=False):
"""
Returns indices of maximum elements obtained by iterating over given axis.
When axis is None (the default value), the argmax is performed
over the flattened tensor.
Parameters
----------
keepdims : bool
If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
"""
argout = max_and_argmax(x, axis)[1]
if keepdims:
argout = makeKeepDims(x, argout, axis)
return argout
def min(x, axis=None, keepdims=False):
"""
Returns minimum elements obtained by iterating over given axis.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论