提交 4fbe2385 authored 作者: Frederic's avatar Frederic

Add c_code for MaxAndArgmax

上级 2b806f0b
......@@ -1422,6 +1422,56 @@ class MaxAndArgmax(Op):
dtype=node.outputs[0].dtype)
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64')
def c_code(self, node, name, inp, out, sub):
x, axis = inp
max, argmax = out
fail = sub["fail"]
assert node.inputs[1].ndim == 1
ret = """
int axis;
if(PyArray_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)){
axis = NPY_MAXDIMS;
}else if(PyArray_SIZE(%(axis)s) == 1){
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s
}
}else{
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s;
}
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
if(%(max)s == NULL){
PyErr_SetString(PyExc_ValueError,
"MaxAndArgmax, max failed");
%(fail)s;
}
if(!PyArray_CheckExact(%(max)s)){
%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(max)s == NULL){
%(fail)s;
}
}
%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
if(%(argmax)s == NULL){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
Py_CLEAR(%(max)s);
%(fail)s;
}
if(!PyArray_CheckExact(%(argmax)s)){
%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(argmax)s == NULL){
%(fail)s;
}
}
"""
return ret % locals()
def c_code_cache_version(self):
return (1,)
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
axis = node.inputs[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论