提交 cbf387b7 authored 作者: Frederic's avatar Frederic

Make MaxAndArgmax accept None as axis and update test to tests more axis.

上级 5ad3c667
...@@ -1885,8 +1885,12 @@ class MaxAndArgmax(Op): ...@@ -1885,8 +1885,12 @@ class MaxAndArgmax(Op):
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis] axis = [axis]
elif isinstance(axis, (tuple, list)): elif isinstance(axis, (tuple, list)):
assert len(axis) == 1, ("MaxAndArgmax don't support multiple" if len(axis) != 1:
" axis. the max fct support it.") list(axis)
axis.sort()
assert axis == range(x.type.ndim), (
"MaxAndArgmax don't support multiple"
" axis. the max fct support it.")
# we make the axis all positive to make the infer_shape work # we make the axis all positive to make the infer_shape work
# with negative axis # with negative axis
if x.type.ndim > 0 and axis is not None: if x.type.ndim > 0 and axis is not None:
...@@ -1901,8 +1905,7 @@ class MaxAndArgmax(Op): ...@@ -1901,8 +1905,7 @@ class MaxAndArgmax(Op):
axis = _as_tensor_variable(axis) axis = _as_tensor_variable(axis)
inputs = [x, axis] inputs = [x, axis]
#TODO: figure things out if axis is a constant broadcastable = [False] * (x.type.ndim - len(axis.data))
broadcastable = [False] * (x.type.ndim - 1)
outputs = [tensor(x.type.dtype, broadcastable, name='max'), outputs = [tensor(x.type.dtype, broadcastable, name='max'),
tensor('int32', broadcastable, name='argmax')] tensor('int32', broadcastable, name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -1920,6 +1923,10 @@ class MaxAndArgmax(Op): ...@@ -1920,6 +1923,10 @@ class MaxAndArgmax(Op):
axis = node.inputs[1] axis = node.inputs[1]
if axis is None: if axis is None:
return [(), ()] return [(), ()]
elif len(axis.data) == 0 and node.inputs[0].ndim:
return [(1,), (1,)]
elif python_all(axis.data == range(node.inputs[0].ndim)):
return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate( rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data]) node.inputs[0].type.broadcastable) if i != axis.data])
return [rval, rval] return [rval, rval]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论