提交 bbf3f5dd authored 作者: Frederic's avatar Frederic

Make MaxAndArgmax work when axis==None

上级 8660a61c
......@@ -1879,13 +1879,17 @@ class MaxAndArgmax(Op):
elif isinstance(axis,(tuple,list)):
assert len(axis)==1,"MaxAndArgmax don't support multiple axis. the max fct support it."
#we make the axis all positive to make the infer_shape work with negative axis
if x.type.ndim>0:
if x.type.ndim>0 and axis is not None:
for id,a in enumerate(axis):
if not isinstance(a, TensorVariable) and a<0:
if -a>x.type.ndim:
raise ValueError('axis out of range')
axis[id]=x.type.ndim+a
axis = _as_tensor_variable(axis)
if axis is None:
axis = _as_tensor_variable(range(x.type.ndim))
else:
axis = _as_tensor_variable(axis)
inputs = [x, axis]
#TODO: figure things out if axis is a constant
broadcastable = [False] * (x.type.ndim - 1)
......@@ -1895,6 +1899,8 @@ class MaxAndArgmax(Op):
def perform(self, node, inp, outs):
x, axis = inp
max, max_idx = outs
if len(axis) == 0 or python_all(axis == range(x.ndim)):
axis = None
max[0] = numpy.asarray(numpy.max(x, axis))
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论