提交 5002d0fe authored 作者: Harm de Vries's avatar Harm de Vries

support for multiple axes in MaxAndArgMaxOp

上级 e617dc50
......@@ -1286,28 +1286,19 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None):
x = _as_tensor_variable(x)
if isinstance(axis, (tuple, list)):
if isinstance(axis, (tuple, list, numpy.ndarray)):
# List of axes: make them non-negative, and sort them
axis = [int(a) for a in axis]
if len(axis) != 1:
axis = list(axis)
for idx in xrange(len(axis)):
if axis[idx] < 0:
axis[idx] += x.type.ndim
axis.sort()
if axis == list(range(-x.type.ndim, 0, 1)):
axis = list(range(x.type.ndim))
assert axis == list(range(x.type.ndim)), (
"MaxAndArgmax does not support multiple"
" axes. the max fct supports it. Got %s" % axis)
#if axis == list(range(-x.type.ndim, 0, 1)):
#axis = list(range(x.type.ndim))
#assert axis == list(range(x.type.ndim)), (
#"MaxAndArgmax does not support multiple"
#" axes. the max fct supports it. Got %s" % axis)
if axis == list(range(x.type.ndim)):
axis = None
else:
axis = axis[0]
if isinstance(axis, (int, numpy.integer)):
axis = int(axis)
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = int(axis)
elif isinstance(axis, (int, numpy.integer)):
axis = [int(axis)]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = None
......@@ -1317,29 +1308,37 @@ class MaxAndArgmax(Op):
else:
assert (axis.dtype.startswith("int") or
axis.dtype.startswith("uint"))
axis = int(axis.data)
# we make the axis all positive to make the infer_shape work
# with negative axis
if x.type.ndim > 0 and axis is not None:
if axis < 0:
if -axis > x.type.ndim:
raise ValueError('axis out of range')
axis = x.type.ndim + axis
# Verify that the axis is valid.
if isinstance(axis.data, (int, numpy.integer)) or \
(isinstance(axis.data, numpy.ndarray) and axis.data.ndim == 0):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, numpy.ndarray)):
print axis.data
axis = [int(i) for i in axis.data]
# Make axis entries non-negative, and sort them
if isinstance(axis, list):
for idx in xrange(len(axis)):
if axis[idx] < 0:
axis[idx] += x.type.ndim
axis.sort()
# Verify that axes are valid
all_axes = set()
if axis is not None:
if axis < 0 or axis >= x.type.ndim:
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (axis, x.type.ndim))
all_axes.add(axis)
if isinstance(axis, list):
for ax in axis:
if ax < 0 or ax >= x.type.ndim:
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (ax, x.type.ndim))
all_axes.add(ax)
else:
all_axes = list(range(x.ndim))
if axis is None:
if axis is None or axis == list(range(x.type.ndim)):
axis = NoneConst.clone()
else:
axis = _as_tensor_variable(axis)
assert axis.ndim == 0
#assert axis.ndim == 0
inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax.
......@@ -1350,21 +1349,28 @@ class MaxAndArgmax(Op):
return Apply(self, inputs, outputs)
def perform(self, node, inp, outs):
x, axis = inp
x, axes = inp
max, max_idx = outs
max[0] = theano._asarray(numpy.max(x, axis),
max[0] = theano._asarray(numpy.max(x, tuple(axes)),
dtype=node.outputs[0].dtype)
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64')
# Numpy does not support multiple axes for argmax
# Work around,
keep_axes = numpy.array([i for i in range(x.ndim) if i not in axes])
# Not reduced axes in front
transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes)))
reshaped_x = x.reshape(transposed_x.shape[:len(keep_axes)] + (-1,))
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, -1), dtype='int64')
def c_code(self, node, name, inp, out, sub):
x, axis = inp
max, argmax = out
fail = sub["fail"]
if NoneConst.equals(node.inputs[1]):
axis_code = "axis = NPY_MAXDIMS;"
else:
assert node.inputs[1].ndim == 0
assert node.inputs[1].ndim == 1
# Fall back to perform() if there are multiple axes
if len(node.inputs[1].data) > 1: raise NotImplementedError()
axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
......@@ -1420,12 +1426,12 @@ class MaxAndArgmax(Op):
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
axis = node.inputs[1]
if node.inputs[1].data is None:
if axis.data is None:
return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data])
node.inputs[0].type.broadcastable) if i not in axis.data])
return [rval, rval]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None, None]
......@@ -1484,6 +1490,7 @@ class MaxAndArgmax(Op):
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
......@@ -1492,7 +1499,7 @@ class MaxAndArgmax(Op):
# We are taking the max/argmax over all dimensions.
axis = None
for i in xrange(x.ndim):
if axis is None or i == axis.data:
if axis is None or i in axis.data:
pattern.append('x')
else:
pattern.append(out_dim)
......
......@@ -2963,6 +2963,12 @@ class T_max_and_argmax(unittest.TestCase):
x = tensor.matrix().dimshuffle('x', 0, 'x', 1, 'x')
y = x.max(axis=1)
assert y.type.broadcastable == (True, True, False, True)
def test_multiple_axes(self):
data = as_tensor_variable(numpy.arange(24).reshape(3, 2, 4))
v, i = eval_outputs(max_and_argmax(data, [1, -1]))
assert numpy.all(v == numpy.array([7, 15, 23]))
assert numpy.all(i == numpy.array([7, 7, 7]))
class T_argmin_argmax(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论