Commit ca30b831, authored by Frédéric Bastien, committed by GitHub

Merge pull request #5155 from sygi/argmax-only-cpu

Option to evaluate part of the max-and-argmax on cpu
...@@ -1419,6 +1419,170 @@ class MaxAndArgmax(Op): ...@@ -1419,6 +1419,170 @@ class MaxAndArgmax(Op):
_max_and_argmax = MaxAndArgmax() _max_and_argmax = MaxAndArgmax()
class Argmax(Op):
    """
    Calculate the argmax over a given axis or over all axes.

    The op takes two inputs, the tensor and a constant `axis` (``None``, an
    integer or a sequence of integers), and produces a single int64 output
    holding the position of the maximum over the reduced axes.

    """
    nin = 2  # tensor, axis
    nout = 1
    E_axis = 'invalid axis'
    __props__ = ()

    def make_node(self, x, axis=None):
        """
        Build the Apply node computing the argmax of `x` over `axis`.

        Parameters
        ----------
        x
            Anything accepted by ``_as_tensor_variable``.
        axis
            ``None`` (reduce over every axis), an integer, a 0-d or 1-d
            numpy array, a tuple/list of integers, or a constant symbolic
            variable holding one of those. Negative entries are counted
            from the last dimension of `x`.

        Raises
        ------
        TypeError
            If `axis` is a symbolic variable that is not a constant.
        ValueError
            If an axis falls outside ``[0, x.type.ndim)`` once negative
            entries have been normalised.

        """
        x = _as_tensor_variable(x)

        # Bring every accepted spelling of `axis` down to either None or a
        # freshly built list of Python ints.
        if isinstance(axis, (integer_types, numpy.integer)):
            axis = [int(axis)]
        elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
            axis = [int(axis)]
        elif isinstance(axis, (tuple, list, numpy.ndarray)):
            axis = [int(a) for a in axis]
            if axis == list(range(x.type.ndim)):
                axis = None
        elif isinstance(axis, Variable):
            if NoneConst.equals(axis):
                axis = None
            elif not isinstance(axis, TensorConstant):
                raise TypeError(
                    "Argmax needs a constant axis. Got %s" % axis)
            else:
                assert (axis.dtype.startswith("int") or
                        axis.dtype.startswith("uint"))
                if isinstance(axis.data, (integer_types, numpy.integer)) or \
                        (isinstance(axis.data, numpy.ndarray) and
                         axis.data.ndim == 0):
                    axis = [int(axis.data)]
                elif isinstance(axis.data, (list, numpy.ndarray)):
                    axis = [int(i) for i in axis.data]

        # Make axis entries non-negative, and sort them
        if isinstance(axis, list):
            ndim = x.type.ndim
            axis = sorted(a + ndim if a < 0 else a for a in axis)

        # Verify that axes are valid, dropping duplicates along the way
        all_axes = []
        if isinstance(axis, list):
            for ax in axis:
                if ax < 0 or ax >= x.type.ndim:
                    raise ValueError(
                        'Invalid axis: %s (the number of dimensions of the '
                        'input is: %s)' % (ax, x.type.ndim))
                if ax not in all_axes:
                    all_axes.append(ax)
        else:
            all_axes = list(range(x.ndim))

        # Reducing over every axis is encoded as the None constant.
        if axis is None or axis == list(range(x.type.ndim)):
            axis = NoneConst.clone()
        else:
            axis = _as_tensor_variable(all_axes)
            assert axis.ndim == 1
        inputs = [x, axis]

        # We keep the original broadcastable flags for dimensions on which
        # we do not perform the argmax.
        broadcastable = [b for i, b in enumerate(x.type.broadcastable)
                         if i not in all_axes]
        outputs = [tensor('int64', broadcastable, name='argmax')]
        return Apply(self, inputs, outputs)

    def perform(self, node, inp, outs):
        """
        Python implementation: store the int64 argmax in ``outs[0][0]``.

        Parameters
        ----------
        node
            The Apply node being evaluated (unused here).
        inp
            Pair ``(x, axes)``; ``axes`` is ``None`` to reduce over every
            axis of `x`, otherwise an iterable of axis indices.
        outs
            One-element sequence holding the output storage cell.

        """
        x, axes = inp
        max_idx, = outs
        if axes is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axes)

        # Numpy does not support multiple axes for argmax, so as a work
        # around the reduced axes are moved last and flattened into one.
        keep_axes = [i for i in range(x.ndim) if i not in axes]
        # Not-reduced axes in front. The permutation is a plain list of
        # ints: concatenating an integer array with an empty tuple would
        # be promoted to a float array.
        transposed_x = numpy.transpose(x, keep_axes + list(axes))
        kept_shape = transposed_x.shape[:len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes):]
        # numpy.prod returns the float 1.0 for an empty sequence, which is
        # not a valid dimension for reshape; force an integer result.
        new_shape = kept_shape + (numpy.prod(reduced_shape, dtype='int64'),)
        reshaped_x = transposed_x.reshape(new_shape)

        max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1),
                                     dtype='int64')

    def c_code(self, node, name, inp, out, sub):
        """
        Return the C implementation, available for at most one axis.

        Raises
        ------
        NotImplementedError
            If the axis input holds more than one axis.

        """
        # NOTE: both templates are filled in through ``locals()``, so the
        # local names ``x``, ``axis``, ``argmax``, ``fail`` and ``axis_code``
        # must keep these exact spellings. The template text itself is kept
        # verbatim.
        x, axis = inp
        argmax, = out
        fail = sub["fail"]
        if NoneConst.equals(node.inputs[1]):
            axis_code = "axis = NPY_MAXDIMS;"
        else:
            assert node.inputs[1].ndim == 1
            # Fall back to perform() if there are multiple axes
            if len(node.inputs[1].data) > 1:
                raise NotImplementedError()
            axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
%(fail)s
}
""" % locals()
        ret = """
int axis;
Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
%(axis_code)s
%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
if(%(argmax)s == NULL){
%(fail)s;
}
if(!PyArray_CheckExact(%(argmax)s)){
%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(argmax)s == NULL){
%(fail)s;
}
}
if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
if (NULL == tmp){
%(fail)s;
}
Py_DECREF(%(argmax)s);
%(argmax)s = (PyArrayObject*)tmp;
}
"""
        return ret % locals()

    def c_code_cache_version(self):
        """Return the version tuple of the C implementation."""
        return (0,)

    def infer_shape(self, node, shapes):
        """
        Return the output shape: the input shape minus the reduced axes.

        Parameters
        ----------
        node
            The Apply node; its second input carries the constant axes.
        shapes
            Pair ``(input shape, axis shape)``; only the first is used.

        """
        ishape, _ = shapes
        axis = node.inputs[1]
        if axis.data is None:
            # Every axis is reduced, so the output has no dimensions.
            return [()]
        n_dims = len(node.inputs[0].type.broadcastable)
        rval = tuple(ishape[i] for i in range(n_dims)
                     if i not in axis.data)
        return [rval]

    def grad(self, inp, grads):
        """
        Return a zero gradient for `x` and an undefined one for `axis`.
        """
        x, axis = inp
        axis_grad = grad_undefined(
            self, 1, axis,
            "argmax is not defined for non-integer axes so"
            " argmax(x, axis+eps) is undefined")
        return [x.zeros_like(), axis_grad]


_argmax = Argmax()
def makeKeepDims(x, y, axis): def makeKeepDims(x, y, axis):
""" """
Reintroduces in y with length one the axes of x which have been left out Reintroduces in y with length one the axes of x which have been left out
...@@ -1541,9 +1705,6 @@ def argmax(x, axis=None, keepdims=False): ...@@ -1541,9 +1705,6 @@ def argmax(x, axis=None, keepdims=False):
will broadcast correctly against the original tensor. will broadcast correctly against the original tensor.
""" """
# In python (using MaxAndArgmax.perform()) this leads to a wasteful
# implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine.
argout = max_and_argmax(x, axis)[1] argout = max_and_argmax(x, axis)[1]
if keepdims: if keepdims:
......
...@@ -1314,10 +1314,10 @@ def test_argmax_pushdown(): ...@@ -1314,10 +1314,10 @@ def test_argmax_pushdown():
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 2 # an output_guard is second assert len(fgraph.toposort()) == 2 # an output_guard is second
assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax assert fgraph.toposort()[0].op == tensor.basic._argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard' assert str(fgraph.toposort()[1].op) == 'OutputGuard'
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=tensor.basic._max_and_argmax) fgraph, ops_to_check=tensor.basic._argmax)
x = tensor.matrix() x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used # test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax( out = tensor.max_and_argmax(
...@@ -1362,7 +1362,7 @@ def test_argmax_pushdown_bias(): ...@@ -1362,7 +1362,7 @@ def test_argmax_pushdown_bias():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.MaxAndArgmax) types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
assert len(fgraph.toposort()) == 4 assert len(fgraph.toposort()) == 4
for i, type in enumerate(types_to_check): for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type) assert isinstance(fgraph.toposort()[i].op, type)
......
...@@ -73,6 +73,9 @@ def local_max_and_argmax(node): ...@@ -73,6 +73,9 @@ def local_max_and_argmax(node):
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None] return [new, None]
if len(node.outputs[0].clients) == 0:
return [None, T._argmax(node.inputs[0], node.inputs[1])]
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.neg]) @gof.local_optimizer([T.neg])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论