提交 2e991052 authored 作者: Frederic's avatar Frederic

Allow mixing positiv and negativ axis.

上级 c437a2b7
...@@ -2653,7 +2653,10 @@ class MaxAndArgmax(Op): ...@@ -2653,7 +2653,10 @@ class MaxAndArgmax(Op):
axis = [axis] axis = [axis]
elif isinstance(axis, (tuple, list)): elif isinstance(axis, (tuple, list)):
if len(axis) != 1: if len(axis) != 1:
list(axis) axis = list(axis)
for idx in range(len(axis)):
if axis[idx] < 0:
axis[idx] += x.type.ndim
axis.sort() axis.sort()
if axis == range(-x.type.ndim, 0, 1): if axis == range(-x.type.ndim, 0, 1):
axis = range(x.type.ndim) axis = range(x.type.ndim)
......
...@@ -36,7 +36,8 @@ class TestKeepDims: ...@@ -36,7 +36,8 @@ class TestKeepDims:
# 'max_and_argmax' has two outputs and can be specified with either # 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis: # a single or every axis:
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2], for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3]]: [-1], [-2], [-3], [-1, -2, -3], [0, -1, -2],
[-2, -3, 2]]:
op = tensor.max_and_argmax op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0]) keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
...@@ -57,7 +58,7 @@ class TestKeepDims: ...@@ -57,7 +58,7 @@ class TestKeepDims:
# axis: # axis:
for op in ([tensor.argmax, tensor.argmin]): for op in ([tensor.argmax, tensor.argmin]):
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2], for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3], [-1, -2, -3]]: [-1], [-2], [-3], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
...@@ -79,7 +80,7 @@ class TestKeepDims: ...@@ -79,7 +80,7 @@ class TestKeepDims:
tensor.std, tensor.all, tensor.any, tensor.std, tensor.all, tensor.any,
tensor.max, tensor.min]): tensor.max, tensor.min]):
for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2], for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3]]: [-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论