提交 8e512ce0 authored 作者: Frederic's avatar Frederic

Allow numpy.asarray(cuda_ndarray, dtype=...)

上级 dcd43ac9
...@@ -348,8 +348,38 @@ static PyMemberDef CudaNdarray_members[] = ...@@ -348,8 +348,38 @@ static PyMemberDef CudaNdarray_members[] =
{NULL} /* Sentinel */ {NULL} /* Sentinel */
}; };
PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self) PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args)
{ {
PyObject * dtype = NULL;
if (! PyArg_ParseTuple(args, "|O", &dtype))
return NULL;
if (dtype) {
PyArray_Descr* dtype2;
// PyArray_DescrConverter try to convert anything to a PyArray_Descr.
if(!PyArray_DescrConverter(dtype, &dtype2))
{
PyObject * str = PyObject_Repr(dtype);
PyErr_Format(PyExc_TypeError,
"CudaNdarray dtype parameter not understood: %s",
PyString_AsString(str)
);
Py_CLEAR(str);
return NULL;
}
int typeNum = dtype2->type_num;
Py_DECREF(dtype2);
if (typeNum != NPY_FLOAT32)
{
PyObject * str = PyObject_Repr(dtype);
PyErr_Format(PyExc_TypeError,
"CudaNdarray support only support float32 dtype, provided: %d",
typeNum
);
Py_CLEAR(str);
return NULL;
}
}
int verbose = 0; int verbose = 0;
if(self->nd>=0 && CudaNdarray_SIZE(self)==0){ if(self->nd>=0 && CudaNdarray_SIZE(self)==0){
npy_intp * npydims = (npy_intp*)malloc(self->nd * sizeof(npy_intp)); npy_intp * npydims = (npy_intp*)malloc(self->nd * sizeof(npy_intp));
...@@ -1217,7 +1247,7 @@ CudaNdarray_exp(CudaNdarray* self) ...@@ -1217,7 +1247,7 @@ CudaNdarray_exp(CudaNdarray* self)
static PyMethodDef CudaNdarray_methods[] = static PyMethodDef CudaNdarray_methods[] =
{ {
{"__array__", {"__array__",
(PyCFunction)CudaNdarray_CreateArrayObj, METH_NOARGS, (PyCFunction)CudaNdarray_CreateArrayObj, METH_VARARGS,
"Copy from the device to a numpy ndarray"}, "Copy from the device to a numpy ndarray"},
{"__copy__", {"__copy__",
(PyCFunction)CudaNdarray_View, METH_NOARGS, (PyCFunction)CudaNdarray_View, METH_NOARGS,
......
...@@ -473,7 +473,7 @@ DllExport int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self, ...@@ -473,7 +473,7 @@ DllExport int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
* Transfer the contents of CudaNdarray `self` to a new numpy ndarray. * Transfer the contents of CudaNdarray `self` to a new numpy ndarray.
*/ */
DllExport PyObject * DllExport PyObject *
CudaNdarray_CreateArrayObj(CudaNdarray * self); CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args = NULL);
DllExport PyObject * DllExport PyObject *
CudaNdarray_ZEROS(int n, int * dims); CudaNdarray_ZEROS(int n, int * dims);
......
...@@ -38,6 +38,17 @@ def test_host_to_device(): ...@@ -38,6 +38,17 @@ def test_host_to_device():
c = numpy.asarray(b) c = numpy.asarray(b)
assert numpy.all(a == c) assert numpy.all(a == c)
# test with float32 dtype
d = numpy.asarray(b, dtype='float32')
assert numpy.all(a == d)
# test with not float32 dtype
try:
numpy.asarray(b, dtype='int8')
assert False
except TypeError:
pass
def test_add_iadd_idiv(): def test_add_iadd_idiv():
for shapes in ( for shapes in (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论