提交 49ecd787 authored 作者: Frederic's avatar Frederic

Use the new numpy interface.

上级 60b5ccc2
...@@ -498,7 +498,7 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args) ...@@ -498,7 +498,7 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args)
if (!rval){ if (!rval){
return NULL; return NULL;
} }
assert (PyArray_ITEMSIZE(rval) == sizeof(real)); assert (PyArray_ITEMSIZE((PyArrayObject *)rval) == sizeof(real));
return rval; return rval;
} }
if ((self->nd < 0) || (self->devdata == 0)) if ((self->nd < 0) || (self->devdata == 0))
...@@ -527,7 +527,9 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args) ...@@ -527,7 +527,9 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args)
assert (npydims); assert (npydims);
for (int i = 0; i < self->nd; ++i) for (int i = 0; i < self->nd; ++i)
npydims[i] = (npy_intp)(CudaNdarray_HOST_DIMS(self)[i]); npydims[i] = (npy_intp)(CudaNdarray_HOST_DIMS(self)[i]);
PyObject * rval = PyArray_SimpleNew(self->nd, npydims, REAL_TYPENUM); PyArrayObject * rval = (PyArrayObject *) PyArray_SimpleNew(self->nd,
npydims,
REAL_TYPENUM);
free(npydims); free(npydims);
if (!rval) if (!rval)
{ {
...@@ -555,7 +557,7 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args) ...@@ -555,7 +557,7 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args)
} }
Py_DECREF(contiguous_self); Py_DECREF(contiguous_self);
return rval; return (PyObject *)rval;
} }
// TODO-- we have two functions here, ZEROS and Zeros. // TODO-- we have two functions here, ZEROS and Zeros.
...@@ -978,7 +980,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){ ...@@ -978,7 +980,7 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args){
return NULL; return NULL;
if (verbose) printf("ndarray indices\n"); if (verbose) printf("ndarray indices\n");
if (PyArray_TYPE(indices_obj) != NPY_INT32) { if (PyArray_TYPE((PyArrayObject *)indices_obj) != NPY_INT32) {
PyErr_SetString(PyExc_TypeError, "CudaNdarray_TakeFrom: need a ndarray for indices with dtype int32"); PyErr_SetString(PyExc_TypeError, "CudaNdarray_TakeFrom: need a ndarray for indices with dtype int32");
return NULL; return NULL;
} }
...@@ -3357,7 +3359,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s ...@@ -3357,7 +3359,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
} }
for (int i = 0; i < PyArray_NDIM(data); ++i) for (int i = 0; i < PyArray_NDIM(data); ++i)
{ {
if ((data->dimensions[i] > 1) && PyInt_AsLong(PyTuple_GetItem(broadcastable, Py_ssize_t(i)))) if ((PyArray_DIMS(data)[i] > 1) && PyInt_AsLong(PyTuple_GetItem(broadcastable, Py_ssize_t(i))))
{ {
PyErr_Format(PyExc_TypeError, "Non-unit size in broadcastable dimension %i", i); PyErr_Format(PyExc_TypeError, "Non-unit size in broadcastable dimension %i", i);
Py_DECREF(data); Py_DECREF(data);
...@@ -3578,7 +3580,8 @@ cublas_shutdown() ...@@ -3578,7 +3580,8 @@ cublas_shutdown()
int int
CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj) CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
{ {
int err = CudaNdarray_alloc_contiguous(self, PyArray_NDIM(obj), obj->dimensions); int err = CudaNdarray_alloc_contiguous(self, PyArray_NDIM(obj),
PyArray_DIMS(obj));
if (err) { if (err) {
return err; return err;
} }
...@@ -3590,7 +3593,8 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj) ...@@ -3590,7 +3593,8 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
return -1; return -1;
} }
assert( 4 == PyArray_ITEMSIZE(obj)); assert( 4 == PyArray_ITEMSIZE(obj));
PyObject * py_src = PyArray_ContiguousFromAny((PyObject*)obj, typenum, self->nd, self->nd); PyArrayObject * py_src = (PyArrayObject *)PyArray_ContiguousFromAny(
(PyObject*)obj, typenum, self->nd, self->nd);
if (!py_src) { if (!py_src) {
return -1; return -1;
} }
......
...@@ -156,6 +156,12 @@ class NVCC_compiler(object): ...@@ -156,6 +156,12 @@ class NVCC_compiler(object):
os.path.join(os.path.split(__file__)[0], 'cuda_ndarray.cuh')) os.path.join(os.path.split(__file__)[0], 'cuda_ndarray.cuh'))
flags.append('-DCUDA_NDARRAY_CUH=' + cuda_ndarray_cuh_hash) flags.append('-DCUDA_NDARRAY_CUH=' + cuda_ndarray_cuh_hash)
# NumPy 1.7 Deprecate the old API. I updated most of the places
# to use the new API, but not everywhere. When finished, enable
# the following macro to assert that we don't bring new code
# that use the old API.
flags.append("-D NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION")
# numpy 1.7 deprecated the following macro but the didn't # numpy 1.7 deprecated the following macro but the didn't
# existed in the past # existed in the past
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论