提交 f57a7ccb authored 作者: Harm de Vries's avatar Harm de Vries

uncomment c_code

上级 76609200
......@@ -1369,67 +1369,67 @@ class MaxAndArgmax(Op):
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), dtype='int64')
#def c_code(self, node, name, inp, out, sub):
#x, axis = inp
#max, argmax = out
#fail = sub["fail"]
#if NoneConst.equals(node.inputs[1]):
#axis_code = "axis = NPY_MAXDIMS;"
#else:
#assert node.inputs[1].ndim == 1
## Fall back to perform() if there are multiple axes
#if len(node.inputs[1].data) > 1: raise NotImplementedError()
#axis_code = """
#axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
#if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
#PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
#%(fail)s
#}
#""" % locals()
#ret = """
#int axis;
#Py_CLEAR(%(max)s);
#Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
#%(axis_code)s
#%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
#if(%(max)s == NULL){
#PyErr_SetString(PyExc_ValueError,
#"MaxAndArgmax, max failed");
#%(fail)s;
#}
#if(!PyArray_CheckExact(%(max)s)){
#%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
#if(%(max)s == NULL){
#%(fail)s;
#}
#}
#%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
#if(%(argmax)s == NULL){
#PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
#Py_CLEAR(%(max)s);
#%(fail)s;
#}
#if(!PyArray_CheckExact(%(argmax)s)){
#%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
#if(%(argmax)s == NULL){
#%(fail)s;
#}
#}
#if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
#PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
#if (NULL == tmp){
#%(fail)s;
#}
#Py_DECREF(%(argmax)s);
#%(argmax)s = (PyArrayObject*)tmp;
#}
#"""
#return ret % locals()
#def c_code_cache_version(self):
#return (3,)
def c_code(self, node, name, inp, out, sub):
x, axis = inp
max, argmax = out
fail = sub["fail"]
if NoneConst.equals(node.inputs[1]):
axis_code = "axis = NPY_MAXDIMS;"
else:
assert node.inputs[1].ndim == 1
# Fall back to perform() if there are multiple axes
if len(node.inputs[1].data) > 1: raise NotImplementedError()
axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s
}
""" % locals()
ret = """
int axis;
Py_CLEAR(%(max)s);
Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
%(axis_code)s
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
if(%(max)s == NULL){
PyErr_SetString(PyExc_ValueError,
"MaxAndArgmax, max failed");
%(fail)s;
}
if(!PyArray_CheckExact(%(max)s)){
%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(max)s == NULL){
%(fail)s;
}
}
%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
if(%(argmax)s == NULL){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
Py_CLEAR(%(max)s);
%(fail)s;
}
if(!PyArray_CheckExact(%(argmax)s)){
%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(argmax)s == NULL){
%(fail)s;
}
}
if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
if (NULL == tmp){
%(fail)s;
}
Py_DECREF(%(argmax)s);
%(argmax)s = (PyArrayObject*)tmp;
}
"""
return ret % locals()
def c_code_cache_version(self):
return (3,)
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论