提交 c4c19aac authored 作者: notoraptor's avatar notoraptor

Correction in local_max_and_argmax:

correction of casting from new MaxAndArgmax to new Argmax.
上级 2a3aa0fa
......@@ -54,17 +54,17 @@ def local_max_and_argmax(node):
If we don't use the argmax, change it to a max only.
"""
if isinstance(node.op, T.MaxAndArgmax):
axis = node.op.get_params(node)
if len(node.outputs[1].clients) == 0:
# MaxAndArgmax support variable axis,
# but CAReduce support only constant axis.
# Axis il already constant in the new version of MaxAndArgmax.
axis = node.op.get_params(node)
new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None]
if len(node.outputs[0].clients) == 0:
return [None, T._argmax(node.inputs[0], node.inputs[1])]
return [None, T._argmax(node.inputs[0], axis)]
@register_uncanonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论