提交 b524dcfb authored 作者: abergeron's avatar abergeron

Merge pull request #2932 from nouiz/gpu_advsub1

c code for GpuAdvancedSubtensor1
...@@ -2564,6 +2564,10 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp): ...@@ -2564,6 +2564,10 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
if x_.type.ndim == 0: if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar') raise TypeError('cannot index into a scalar')
# The C code assumes the index vector is int64, so upcast narrower integer dtypes.
if x.ndim in [2, 3] and ilist_.dtype in [
'int8', 'int16', 'int32', 'uint8', 'uint16', 'uint32']:
ilist_ = ilist_.cast('int64')
bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:] bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:]
return Apply(self, [x_, ilist_], return Apply(self, [x_, ilist_],
[CudaNdarrayType(dtype=x.dtype, [CudaNdarrayType(dtype=x.dtype,
...@@ -2622,6 +2626,38 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp): ...@@ -2622,6 +2626,38 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
o[j] = x[i] o[j] = x[i]
out[0] = o out[0] = o
def c_code(self, node, name, inputs, outputs, sub):
    """Return the C source that computes this op via CudaNdarray_TakeFrom.

    The generated code builds an argument tuple
    ``(idx, 0, out-or-None, "raise", 512)``, passes it to
    ``CudaNdarray_TakeFrom`` together with ``x`` and stores the result in
    the output variable.

    :param node: the Apply node being compiled; ``node.inputs[0].ndim``
        and ``node.inputs[1].dtype`` are inspected.
    :param name: unused here.
    :param inputs: C variable names ``(x, idx)``.
    :param outputs: one-element sequence holding the output's C variable
        name.
    :param sub: substitution dict; ``sub['fail']`` is spliced in wherever
        the generated code must bail out.

    :raises NotImplementedError: if the first input is not 2- or
        3-dimensional.
    :raises TypeError: if the index input is not of dtype int64.
    """
    x, idx = inputs
    out, = outputs
    fail = sub['fail']
    if node.inputs[0].ndim not in (2, 3):
        raise NotImplementedError("This case does not have C code yet.")
    if node.inputs[1].dtype != 'int64':
        # A dtype mismatch is a type error; TypeError still derives from
        # Exception, so handlers catching Exception keep working.
        raise TypeError(
            "Index should have dtype int64. Check this node make_node().")
    # The trailing 512 is presumably the max_threads argument named in the
    # C comment below -- TODO confirm against CudaNdarray_TakeFrom.
    template = """
//take(idx, 0, out, "raise", max_threads);
PyObject * ret = NULL;
PyObject * args = Py_BuildValue("OiOsi", %(idx)s, 0,
%(out)s == NULL ? Py_None : (PyObject *)%(out)s,
"raise", 512);
if(args == NULL){
//Error set by Py_BuildValue
%(fail)s;
}
ret = CudaNdarray_TakeFrom(%(x)s, args);
Py_DECREF(args);
if (ret == NULL){
%(fail)s;
}
// Even if we decref, we still try to reuse preallocated output
Py_XDECREF(%(out)s);
%(out)s = (CudaNdarray *) ret;
"""
    # Explicit mapping instead of locals(): only these four names are
    # referenced by the template.
    return template % {'x': x, 'idx': idx, 'out': out, 'fail': fail}
def c_code_cache_version(self):
    """Return the cache-version tuple for this op's generated C code.

    NOTE(review): presumably this must be bumped whenever the text
    produced by ``c_code`` changes -- confirm against the project's
    C-code caching rules.
    """
    return (2,)
class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
""" """
......
...@@ -1012,17 +1012,18 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){ ...@@ -1012,17 +1012,18 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
if (verbose) printf("cudandarray indices\n"); if (verbose) printf("cudandarray indices\n");
indices = (CudaNdarray*) indices_obj; indices = (CudaNdarray*) indices_obj;
Py_INCREF(indices); Py_INCREF(indices);
} else if (0 && PyArray_Check(indices_obj)) { } else if (PyArray_Check(indices_obj)) {
PyErr_SetString(PyExc_NotImplementedError, "CudaNdarray_TakeFrom: The indices must cudandarray with float32 value.");
return NULL;
if (verbose) printf("ndarray indices\n"); if (verbose) printf("ndarray indices\n");
if (PyArray_TYPE((PyArrayObject *)indices_obj) != NPY_INT32) { if (PyArray_TYPE((PyArrayObject *)indices_obj) != NPY_INT64) {
PyErr_SetString(PyExc_TypeError, "CudaNdarray_TakeFrom: need a ndarray for indices with dtype int32"); PyErr_SetString(PyExc_TypeError,
"CudaNdarray_TakeFrom: need a ndarray for indices"
" with dtype int64");
return NULL; return NULL;
} }
if (PyArray_NDIM(((PyArrayObject*)indices_obj)) != 1) { if (PyArray_NDIM(((PyArrayObject*)indices_obj)) != 1) {
PyErr_SetString(PyExc_TypeError, "CudaNdarray_TakeFrom: need a CudaNdarray of indices with only 1 dimensions"); PyErr_SetString(PyExc_TypeError,
"CudaNdarray_TakeFrom: need a CudaNdarray of"
" indices with only 1 dimensions");
return NULL; return NULL;
} }
PyArray_Descr* float32_descr = PyArray_DescrFromType(NPY_FLOAT32); PyArray_Descr* float32_descr = PyArray_DescrFromType(NPY_FLOAT32);
...@@ -1031,10 +1032,6 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){ ...@@ -1031,10 +1032,6 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
float32_descr, NULL); float32_descr, NULL);
Py_DECREF(float32_descr); Py_DECREF(float32_descr);
if (verbose) printf("ndarray indices\n"); if (verbose) printf("ndarray indices\n");
//indices_float32 = PyArray_Cast((PyArrayObject*)indices_obj,
// NPY_FLOAT32);
//Py_INCREF(indices_float32);
if (verbose) printf("ndarray indices\n");
if (!indices_float32) if (!indices_float32)
return NULL; return NULL;
...@@ -1047,7 +1044,6 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){ ...@@ -1047,7 +1044,6 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
if (CudaNdarray_CopyFromArray(indices, if (CudaNdarray_CopyFromArray(indices,
(PyArrayObject *)indices_float32)){ (PyArrayObject *)indices_float32)){
Py_DECREF(indices_float32); Py_DECREF(indices_float32);
return NULL; return NULL;
} }
Py_DECREF(indices_float32); Py_DECREF(indices_float32);
...@@ -1076,16 +1072,12 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){ ...@@ -1076,16 +1072,12 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
//Check argument axis //Check argument axis
//TODO: implement the default and other axis //TODO: implement the default and other axis
PyObject * axis_iobj = PyNumber_Long(axis_obj); long axis = PyInt_AsLong(axis_obj);
if (!axis_iobj) {
PyErr_SetString(PyExc_NotImplementedError,"CudaNdarray_TakeFrom: axis must be convertable to a long");
Py_DECREF(indices);
return NULL;
}
long axis = PyInt_AsLong(axis_iobj);
Py_DECREF(axis_iobj); axis_iobj=NULL;
if (axis != 0) { if (axis != 0) {
PyErr_SetString(PyExc_NotImplementedError,"CudaNdarray_TakeFrom: only axis=0 is currently supported"); PyErr_Format(PyExc_NotImplementedError,
"CudaNdarray_TakeFrom: only axis=0 is currently supported."
" Got %ld.", axis);
Py_DECREF(indices); Py_DECREF(indices);
return NULL; return NULL;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论