提交 f57a7ccb authored 作者: Harm de Vries's avatar Harm de Vries

uncomment c_code

上级 76609200
...@@ -1369,67 +1369,67 @@ class MaxAndArgmax(Op): ...@@ -1369,67 +1369,67 @@ class MaxAndArgmax(Op):
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), dtype='int64') max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), dtype='int64')
#def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
#x, axis = inp x, axis = inp
#max, argmax = out max, argmax = out
#fail = sub["fail"] fail = sub["fail"]
#if NoneConst.equals(node.inputs[1]): if NoneConst.equals(node.inputs[1]):
#axis_code = "axis = NPY_MAXDIMS;" axis_code = "axis = NPY_MAXDIMS;"
#else: else:
#assert node.inputs[1].ndim == 1 assert node.inputs[1].ndim == 1
## Fall back to perform() if there are multiple axes # Fall back to perform() if there are multiple axes
#if len(node.inputs[1].data) > 1: raise NotImplementedError() if len(node.inputs[1].data) > 1: raise NotImplementedError()
#axis_code = """ axis_code = """
#axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0]; axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
#if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){ if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
#PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument"); PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
#%(fail)s %(fail)s
#} }
#""" % locals() """ % locals()
#ret = """ ret = """
#int axis; int axis;
#Py_CLEAR(%(max)s); Py_CLEAR(%(max)s);
#Py_CLEAR(%(argmax)s);//todo pass them as out parameter. Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
#%(axis_code)s %(axis_code)s
#%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL); %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
#if(%(max)s == NULL){ if(%(max)s == NULL){
#PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
#"MaxAndArgmax, max failed"); "MaxAndArgmax, max failed");
#%(fail)s; %(fail)s;
#} }
#if(!PyArray_CheckExact(%(max)s)){ if(!PyArray_CheckExact(%(max)s)){
#%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
#if(%(max)s == NULL){ if(%(max)s == NULL){
#%(fail)s; %(fail)s;
#} }
#} }
#%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL); %(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
#if(%(argmax)s == NULL){ if(%(argmax)s == NULL){
#PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed"); PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
#Py_CLEAR(%(max)s); Py_CLEAR(%(max)s);
#%(fail)s; %(fail)s;
#} }
#if(!PyArray_CheckExact(%(argmax)s)){ if(!PyArray_CheckExact(%(argmax)s)){
#%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); %(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
#if(%(argmax)s == NULL){ if(%(argmax)s == NULL){
#%(fail)s; %(fail)s;
#} }
#} }
#if(PyArray_TYPE(%(argmax)s) != NPY_INT64){ if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
#PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64); PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
#if (NULL == tmp){ if (NULL == tmp){
#%(fail)s; %(fail)s;
#} }
#Py_DECREF(%(argmax)s); Py_DECREF(%(argmax)s);
#%(argmax)s = (PyArrayObject*)tmp; %(argmax)s = (PyArrayObject*)tmp;
#} }
#""" """
#return ret % locals() return ret % locals()
#def c_code_cache_version(self): def c_code_cache_version(self):
#return (3,) return (3,)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
ishape, axis_shape = shapes ishape, axis_shape = shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论