提交 c437a2b7 authored 作者: Frederic's avatar Frederic

allow keepdims with negative index.

Also also MaxAndArgmax with all neg axis. fix gh-1430
上级 941a91e4
...@@ -2655,6 +2655,8 @@ class MaxAndArgmax(Op): ...@@ -2655,6 +2655,8 @@ class MaxAndArgmax(Op):
if len(axis) != 1: if len(axis) != 1:
list(axis) list(axis)
axis.sort() axis.sort()
if axis == range(-x.type.ndim, 0, 1):
axis = range(x.type.ndim)
assert axis == range(x.type.ndim), ( assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple" "MaxAndArgmax does not support multiple"
" axes. the max fct supports it.") " axes. the max fct supports it.")
...@@ -2806,10 +2808,17 @@ def makeKeepDims(x, y, axis): ...@@ -2806,10 +2808,17 @@ def makeKeepDims(x, y, axis):
axis = range(x.type.ndim) axis = range(x.type.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
newaxis = []
for a in axis:
if not isinstance(a, int):
raise ValueError("keepdims option can be used only with constant axis")
if a < 0:
a += x.type.ndim
newaxis.append(a)
i = 0 i = 0
new_dims = [] new_dims = []
for j, _ in enumerate(x.type.broadcastable): for j, _ in enumerate(x.type.broadcastable):
if j in axis: if j in newaxis:
new_dims.append('x') new_dims.append('x')
else: else:
new_dims.append(i) new_dims.append(i)
......
...@@ -13,9 +13,14 @@ class TestKeepDims: ...@@ -13,9 +13,14 @@ class TestKeepDims:
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
i = 0 i = 0
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
new_dims = [] new_dims = []
for j, _ in enumerate(x.shape): for j, _ in enumerate(x.shape):
if j in axis: if j in newaxis:
new_dims.append('x') new_dims.append('x')
else: else:
new_dims.append(i) new_dims.append(i)
...@@ -30,7 +35,8 @@ class TestKeepDims: ...@@ -30,7 +35,8 @@ class TestKeepDims:
# 'max_and_argmax' has two outputs and can be specified with either # 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis: # a single or every axis:
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]: for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3]]:
op = tensor.max_and_argmax op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0]) keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
...@@ -50,8 +56,8 @@ class TestKeepDims: ...@@ -50,8 +56,8 @@ class TestKeepDims:
# the following ops can be specified with either a single axis or every # the following ops can be specified with either a single axis or every
# axis: # axis:
for op in ([tensor.argmax, tensor.argmin]): for op in ([tensor.argmax, tensor.argmin]):
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2]]: [-1], [-2], [-3], [-1, -2, -3]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
...@@ -72,7 +78,8 @@ class TestKeepDims: ...@@ -72,7 +78,8 @@ class TestKeepDims:
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var, for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std, tensor.all, tensor.any, tensor.std, tensor.all, tensor.any,
tensor.max, tensor.min]): tensor.max, tensor.min]):
for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]: for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3]]:
keep_param = function([x], op(x, axis=axis, keepdims=True)) keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x, keep_synth = function([x], self.makeKeepDims_local(x,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论