提交 b4db3c9a authored 作者: Frederic's avatar Frederic

c code for GpuAdvancedSubtensor1

上级 cfbe7c3c
......@@ -2556,6 +2556,10 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
# c code suppose it is int64
if x.ndim in [2, 3] and ilist_.dtype in [
'int8', 'int16', 'int32', 'uint8', 'uint16', 'uint32']:
ilist_ = ilist_.cast('int64')
bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:]
return Apply(self, [x_, ilist_],
[CudaNdarrayType(dtype=x.dtype,
......@@ -2614,6 +2618,52 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
o[j] = x[i]
out[0] = o
def c_code(self, node, name, inputs, outputs, sub):
x, idx = inputs
out, = outputs
fail = sub['fail']
if node.inputs[0].ndim not in [2, 3]:
raise NotImplementedError("This case does not have C code yet.")
if node.inputs[1].dtype != 'int64':
raise Exception("Index should have dtype int64. Check this node make_node().")
return """
int max_threads=512;
//take(idx, 0, out, "raise", max_threads);
PyObject * out_ = NULL;
PyObject * ret = NULL;
Py_INCREF(%(x)s);
PyObject * args = PyTuple_New(5);
PyObject * zero = PyInt_FromLong(0);
PyObject * max = PyInt_FromLong(max_threads);
PyObject * raise = PyString_FromString("raise");
if(args == NULL || zero == NULL || max == NULL || raise == NULL){
%(fail)s;
}
out_ = (PyObject *) %(out)s;
if (out_ == NULL)
out_ = Py_None;
else
Py_INCREF(out_);
PyTuple_SetItem(args, 0, (PyObject *) %(idx)s);
PyTuple_SetItem(args, 1, zero);
PyTuple_SetItem(args, 2, out_);
PyTuple_SetItem(args, 3, raise);
PyTuple_SetItem(args, 4, max);
ret = CudaNdarray_TakeFrom(%(x)s, args);
Py_DECREF(args);
if (ret == NULL){
%(fail)s;
}
// Even if we decref, we still try to reuse preallocated output
Py_XDECREF(%(out)s);
%(out)s = (CudaNdarray *) ret;
""" % locals()
def c_code_cache_version(self):
return (1,)
class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
"""
......
......@@ -1012,17 +1012,18 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
if (verbose) printf("cudandarray indices\n");
indices = (CudaNdarray*) indices_obj;
Py_INCREF(indices);
} else if (0 && PyArray_Check(indices_obj)) {
PyErr_SetString(PyExc_NotImplementedError, "CudaNdarray_TakeFrom: The indices must cudandarray with float32 value.");
return NULL;
} else if (PyArray_Check(indices_obj)) {
if (verbose) printf("ndarray indices\n");
if (PyArray_TYPE((PyArrayObject *)indices_obj) != NPY_INT32) {
PyErr_SetString(PyExc_TypeError, "CudaNdarray_TakeFrom: need a ndarray for indices with dtype int32");
if (PyArray_TYPE((PyArrayObject *)indices_obj) != NPY_INT64) {
PyErr_SetString(PyExc_TypeError,
"CudaNdarray_TakeFrom: need a ndarray for indices"
" with dtype int64");
return NULL;
}
if (PyArray_NDIM(((PyArrayObject*)indices_obj)) != 1) {
PyErr_SetString(PyExc_TypeError, "CudaNdarray_TakeFrom: need a CudaNdarray of indices with only 1 dimensions");
PyErr_SetString(PyExc_TypeError,
"CudaNdarray_TakeFrom: need a CudaNdarray of"
" indices with only 1 dimensions");
return NULL;
}
PyArray_Descr* float32_descr = PyArray_DescrFromType(NPY_FLOAT32);
......@@ -1031,10 +1032,6 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
float32_descr, NULL);
Py_DECREF(float32_descr);
if (verbose) printf("ndarray indices\n");
//indices_float32 = PyArray_Cast((PyArrayObject*)indices_obj,
// NPY_FLOAT32);
//Py_INCREF(indices_float32);
if (verbose) printf("ndarray indices\n");
if (!indices_float32)
return NULL;
......@@ -1047,7 +1044,6 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
if (CudaNdarray_CopyFromArray(indices,
(PyArrayObject *)indices_float32)){
Py_DECREF(indices_float32);
return NULL;
}
Py_DECREF(indices_float32);
......@@ -1076,16 +1072,12 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
//Check argument axis
//TODO: implement the default and other axis
PyObject * axis_iobj = PyNumber_Long(axis_obj);
if (!axis_iobj) {
PyErr_SetString(PyExc_NotImplementedError,"CudaNdarray_TakeFrom: axis must be convertable to a long");
Py_DECREF(indices);
return NULL;
}
long axis = PyInt_AsLong(axis_iobj);
Py_DECREF(axis_iobj); axis_iobj=NULL;
long axis = PyInt_AsLong(axis_obj);
if (axis != 0) {
PyErr_SetString(PyExc_NotImplementedError,"CudaNdarray_TakeFrom: only axis=0 is currently supported");
PyErr_Format(PyExc_NotImplementedError,
"CudaNdarray_TakeFrom: only axis=0 is currently supported."
" Got %ld.", axis);
Py_DECREF(indices);
return NULL;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论