提交 c4c19aac authored 作者: notoraptor's avatar notoraptor

Correction in local_max_and_argmax:

correction of casting from new MaxAndArgmax to new Argmax.
上级 2a3aa0fa
...@@ -54,17 +54,17 @@ def local_max_and_argmax(node): ...@@ -54,17 +54,17 @@ def local_max_and_argmax(node):
If we don't use the argmax, change it to a max only. If we don't use the argmax, change it to a max only.
""" """
if isinstance(node.op, T.MaxAndArgmax): if isinstance(node.op, T.MaxAndArgmax):
axis = node.op.get_params(node)
if len(node.outputs[1].clients) == 0: if len(node.outputs[1].clients) == 0:
# MaxAndArgmax support variable axis, # MaxAndArgmax support variable axis,
# but CAReduce support only constant axis. # but CAReduce support only constant axis.
# Axis il already constant in the new version of MaxAndArgmax. # Axis il already constant in the new version of MaxAndArgmax.
axis = node.op.get_params(node)
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None] return [new, None]
if len(node.outputs[0].clients) == 0: if len(node.outputs[0].clients) == 0:
return [None, T._argmax(node.inputs[0], node.inputs[1])] return [None, T._argmax(node.inputs[0], axis)]
@register_uncanonicalize @register_uncanonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论