Commit ddc7ed83 authored by Pascal Lamblin

Merge pull request #1673 from nouiz/argmax

[MRG] Add c_code for MaxAndArgmax
...@@ -1709,6 +1709,7 @@ class GCC_compiler(object): ...@@ -1709,6 +1709,7 @@ class GCC_compiler(object):
# numpy 1.7 deprecated the following macro but the new one didn't # numpy 1.7 deprecated the following macro but the new one didn't
# existed in the past # existed in the past
if bool(numpy_ver < [1, 7]): if bool(numpy_ver < [1, 7]):
cxxflags.append("-D NPY_ARRAY_ENSUREARRAY=NPY_ENSUREARRAY")
cxxflags.append("-D NPY_ARRAY_ENSURECOPY=NPY_ENSURECOPY") cxxflags.append("-D NPY_ARRAY_ENSURECOPY=NPY_ENSURECOPY")
cxxflags.append("-D NPY_ARRAY_ALIGNED=NPY_ALIGNED") cxxflags.append("-D NPY_ARRAY_ALIGNED=NPY_ALIGNED")
cxxflags.append("-D NPY_ARRAY_WRITEABLE=NPY_WRITEABLE") cxxflags.append("-D NPY_ARRAY_WRITEABLE=NPY_WRITEABLE")
......
...@@ -19,6 +19,7 @@ from theano.tensor.var import (AsTensorError, TensorVariable, ...@@ -19,6 +19,7 @@ from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstant,
_tensor_py_operators) _tensor_py_operators)
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
from theano.tensor.type_other import NoneConst
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any, all from theano.gof.python25 import partial, any, all
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
...@@ -1366,11 +1367,7 @@ class MaxAndArgmax(Op): ...@@ -1366,11 +1367,7 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
if isinstance(axis, (int, numpy.integer)): if isinstance(axis, (tuple, list)):
axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, (tuple, list)):
axis = [int(a) for a in axis] axis = [int(a) for a in axis]
if len(axis) != 1: if len(axis) != 1:
axis = list(axis) axis = list(axis)
...@@ -1383,34 +1380,41 @@ class MaxAndArgmax(Op): ...@@ -1383,34 +1380,41 @@ class MaxAndArgmax(Op):
assert axis == range(x.type.ndim), ( assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple" "MaxAndArgmax does not support multiple"
" axes. the max fct supports it.") " axes. the max fct supports it.")
axis = None
else:
axis = axis[0]
if isinstance(axis, (int, numpy.integer)):
axis = int(axis)
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = int(axis)
elif isinstance(axis, Variable): elif isinstance(axis, Variable):
if not isinstance(axis, TensorConstant): if not isinstance(axis, TensorConstant):
raise TypeError("MaxAndArgmax needs a constant axis") raise TypeError("MaxAndArgmax needs a constant axis")
axis = axis.data assert axis.dtype.startswith("int") or axis.dtype.startswith("uint")
if axis.ndim == 0: axis = int(axis.data)
axis = [axis]
# we make the axis all positive to make the infer_shape work # we make the axis all positive to make the infer_shape work
# with negative axis # with negative axis
if x.type.ndim > 0 and axis is not None: if x.type.ndim > 0 and axis is not None:
for id, a in enumerate(axis): if axis < 0:
if not isinstance(a, TensorVariable) and a < 0: if -axis > x.type.ndim:
if -a > x.type.ndim: raise ValueError('axis out of range')
raise ValueError('axis out of range') axis = x.type.ndim + axis
axis[id] = x.type.ndim + a
if axis is None:
axis = _as_tensor_variable(range(x.type.ndim))
else:
axis = _as_tensor_variable(axis)
# Verify that the axis is valid. # Verify that the axis is valid.
all_axes = set() all_axes = set()
for ax in axis.data: if axis is not None:
if ax < 0 or ax >= x.type.ndim: if axis < 0 or axis >= x.type.ndim:
raise ValueError( raise ValueError(
'Invalid axis: %s (the number of dimensions of the ' 'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (axis, x.type.ndim)) 'input is: %s)' % (axis, x.type.ndim))
all_axes.add(ax.item()) all_axes.add(axis)
assert axis.ndim == 1 else:
all_axes = range(x.ndim)
if axis is None:
axis = NoneConst.clone()
else:
axis = _as_tensor_variable(axis)
assert axis.ndim == 0
inputs = [x, axis] inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which # We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax. # we do not perform the max / argmax.
...@@ -1423,16 +1427,61 @@ class MaxAndArgmax(Op): ...@@ -1423,16 +1427,61 @@ class MaxAndArgmax(Op):
def perform(self, node, inp, outs):
    """Compute the maximum and the argmax of x along the given axis.

    Parameters
    ----------
    inp : (x, axis)
        `x` is the input ndarray; `axis` is either an int or None
        (None means reduce over all dimensions, matching the NoneConst
        axis input built by make_node -- TODO confirm against make_node).
    outs : (max, argmax)
        Single-element output storage cells, filled in place.
    """
    x, axis = inp
    # Avoid shadowing the builtin `max` with the first output cell.
    max_out, argmax_out = outs
    # numpy.max/argmax accept axis=None and reduce over all axes,
    # so no special-casing of the "all axes" reduction is needed.
    max_out[0] = theano._asarray(numpy.max(x, axis),
                                 dtype=node.outputs[0].dtype)
    # Argmax indices are always int64, the dtype declared for output 1.
    argmax_out[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64')
def c_code(self, node, name, inp, out, sub):
    """Return C code computing both the max and the argmax of x.

    The axis input is either NoneConst (reduce over all axes, mapped to
    NPY_MAXDIMS) or a 0-d integer constant.  The results of PyArray_Max
    and PyArray_ArgMax may be ndarray subclasses; they are converted to
    base ndarrays with PyArray_FromAny when needed.
    """
    x, axis = inp
    max, argmax = out
    fail = sub["fail"]
    assert NoneConst.equals(node.inputs[1]) or node.inputs[1].ndim == 0
    ret = """
    int axis;

    if((PyObject*)%(axis)s == Py_None){
        // NPY_MAXDIMS asks PyArray_Max/ArgMax to reduce over all axes.
        axis = NPY_MAXDIMS;
    }else{
        axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
        if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
            PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
            %(fail)s
        }
    }

    %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
    if(%(max)s == NULL){
        PyErr_SetString(PyExc_ValueError,
                        "MaxAndArgmax, max failed");
        %(fail)s;
    }
    if(!PyArray_CheckExact(%(max)s)){
        // PyArray_FromAny returns a new reference; release the old one
        // first, otherwise the result of PyArray_Max leaks on every call.
        PyArrayObject* max_tmp = (PyArrayObject*)PyArray_FromAny(
            (PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
        Py_DECREF(%(max)s);
        %(max)s = max_tmp;
        if(%(max)s == NULL){
            %(fail)s;
        }
    }

    %(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
    if(%(argmax)s == NULL){
        PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
        // Drop the already-computed max so the outputs stay consistent.
        Py_CLEAR(%(max)s);
        %(fail)s;
    }
    if(!PyArray_CheckExact(%(argmax)s)){
        // Same leak fix as for the max output above.
        PyArrayObject* argmax_tmp = (PyArrayObject*)PyArray_FromAny(
            (PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
        Py_DECREF(%(argmax)s);
        %(argmax)s = argmax_tmp;
        if(%(argmax)s == NULL){
            %(fail)s;
        }
    }
    """
    return ret % locals()
def c_code_cache_version(self):
    # Version tag for the C implementation above; bump it whenever
    # c_code changes so previously compiled modules are invalidated.
    return (1,)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
ishape, axis_shape = shapes ishape, axis_shape = shapes
axis = node.inputs[1] axis = node.inputs[1]
if python_all(axis.data == range(node.inputs[0].ndim)): if node.inputs[1].data is None:
return [(), ()] return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate( rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data]) node.inputs[0].type.broadcastable) if i != axis.data])
...@@ -1490,12 +1539,16 @@ class MaxAndArgmax(Op): ...@@ -1490,12 +1539,16 @@ class MaxAndArgmax(Op):
# the gradient on its inputs is zero # the gradient on its inputs is zero
if g_max_disconnected: if g_max_disconnected:
return [x.zeros_like(), axis_grad] return [x.zeros_like(), axis_grad]
xmax = max(x, axis) if NoneConst.equals(axis):
axis_ = range(x.ndim)
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input. # Raise the g_max and xmax to the same number of dim as the input.
pattern = [] pattern = []
out_dim = 0 out_dim = 0
if python_all(axis.data == range(x.ndim)): if NoneConst.equals(axis):
# We are taking the max/argmax over all dimensions. # We are taking the max/argmax over all dimensions.
axis = None axis = None
for i in range(x.ndim): for i in range(x.ndim):
......
...@@ -3313,6 +3313,7 @@ def local_cut_useless_reduce(node): ...@@ -3313,6 +3313,7 @@ def local_cut_useless_reduce(node):
# see gh-790 issue. # see gh-790 issue.
# #
#@register_canonicalize #@register_canonicalize
@register_uncanonicalize
@register_specialize @register_specialize
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
def local_reduce_broadcastable(node): def local_reduce_broadcastable(node):
......
...@@ -46,10 +46,13 @@ def local_max_and_argmax(node): ...@@ -46,10 +46,13 @@ def local_max_and_argmax(node):
if len(node.outputs[1].clients) == 0: if len(node.outputs[1].clients) == 0:
#MaxAndArgmax support variable axis, #MaxAndArgmax support variable axis,
#but CAReduce support only constant axis. #but CAReduce support only constant axis.
try: if node.inputs[1].data is None:
axis = get_scalar_constant_value(node.inputs[1]) axis = None
except NotScalarConstantError: else:
return False try:
axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None] return [new, None]
......
...@@ -1727,7 +1727,7 @@ advanced_inc_subtensor1 = AdvancedIncSubtensor1() ...@@ -1727,7 +1727,7 @@ advanced_inc_subtensor1 = AdvancedIncSubtensor1()
def as_index_variable(idx): def as_index_variable(idx):
if idx is None: if idx is None:
return NoneConst return NoneConst.clone()
if isinstance(idx, slice): if isinstance(idx, slice):
return make_slice(idx) return make_slice(idx)
idx = theano.tensor.as_tensor_variable(idx) idx = theano.tensor.as_tensor_variable(idx)
......
...@@ -66,4 +66,7 @@ class NoneTypeT(Type): ...@@ -66,4 +66,7 @@ class NoneTypeT(Type):
def __str__(self): def __str__(self):
return "None" return "None"
# This is a variable instance. It can be used only once per fgraph.
# So use NoneConst.clone() before using it in a Theano graph.
# Use NoneConst.equal(x) to check if two variable are NoneConst.
NoneConst = Constant(NoneTypeT(), None, name='None') NoneConst = Constant(NoneTypeT(), None, name='None')
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment