提交 0af0b1a1 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make sure the index array is contiguous before changing dtype

上级 e6a4c073
...@@ -1026,25 +1026,41 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){ ...@@ -1026,25 +1026,41 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
" indices with only 1 dimensions"); " indices with only 1 dimensions");
return NULL; return NULL;
} }
// We need indices_obj to be contiguous, in order to take a view
// with a different dtype.
if (!PyArray_IS_C_CONTIGUOUS((PyArrayObject*) indices_obj)) {
PyObject* indices_obj_contig = PyArray_NewCopy((PyArrayObject*) indices_obj, NPY_CORDER);
if (!indices_obj_contig)
return NULL;
indices_obj = indices_obj_contig;
} else {
// Keep the refcount consistent
Py_INCREF(indices_obj);
}
PyArray_Descr* float32_descr = PyArray_DescrFromType(NPY_FLOAT32); PyArray_Descr* float32_descr = PyArray_DescrFromType(NPY_FLOAT32);
PyObject * indices_float32 = NULL; PyObject * indices_float32 = NULL;
indices_float32 = PyArray_View((PyArrayObject*)indices_obj, indices_float32 = PyArray_View((PyArrayObject*)indices_obj,
float32_descr, NULL); float32_descr, NULL);
if (verbose) printf("ndarray indices\n"); if (verbose) printf("ndarray indices\n");
if (!indices_float32) if (!indices_float32) {
Py_DECREF(indices_obj);
return NULL; return NULL;
}
indices = (CudaNdarray*) CudaNdarray_New(); indices = (CudaNdarray*) CudaNdarray_New();
if (verbose) printf("\nndarray after new\n"); if (verbose) printf("\nndarray after new\n");
if (! indices){ if (! indices){
Py_DECREF(indices_obj);
Py_DECREF(indices_float32); Py_DECREF(indices_float32);
return NULL; return NULL;
} }
if (CudaNdarray_CopyFromArray(indices, if (CudaNdarray_CopyFromArray(indices,
(PyArrayObject *)indices_float32)){ (PyArrayObject *)indices_float32)){
Py_DECREF(indices_obj);
Py_DECREF(indices_float32); Py_DECREF(indices_float32);
return NULL; return NULL;
} }
Py_DECREF(indices_obj);
Py_DECREF(indices_float32); Py_DECREF(indices_float32);
} else { } else {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论