提交 8e512ce0 authored 作者: Frederic's avatar Frederic

Allow numpy.asarray(cuda_ndarray, dtype=...)

上级 dcd43ac9
......@@ -348,8 +348,38 @@ static PyMemberDef CudaNdarray_members[] =
{NULL} /* Sentinel */
};
PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self)
PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args)
{
PyObject * dtype = NULL;
if (! PyArg_ParseTuple(args, "|O", &dtype))
return NULL;
if (dtype) {
PyArray_Descr* dtype2;
// PyArray_DescrConverter try to convert anything to a PyArray_Descr.
if(!PyArray_DescrConverter(dtype, &dtype2))
{
PyObject * str = PyObject_Repr(dtype);
PyErr_Format(PyExc_TypeError,
"CudaNdarray dtype parameter not understood: %s",
PyString_AsString(str)
);
Py_CLEAR(str);
return NULL;
}
int typeNum = dtype2->type_num;
Py_DECREF(dtype2);
if (typeNum != NPY_FLOAT32)
{
PyObject * str = PyObject_Repr(dtype);
PyErr_Format(PyExc_TypeError,
"CudaNdarray support only support float32 dtype, provided: %d",
typeNum
);
Py_CLEAR(str);
return NULL;
}
}
int verbose = 0;
if(self->nd>=0 && CudaNdarray_SIZE(self)==0){
npy_intp * npydims = (npy_intp*)malloc(self->nd * sizeof(npy_intp));
......@@ -1217,7 +1247,7 @@ CudaNdarray_exp(CudaNdarray* self)
static PyMethodDef CudaNdarray_methods[] =
{
{"__array__",
(PyCFunction)CudaNdarray_CreateArrayObj, METH_NOARGS,
(PyCFunction)CudaNdarray_CreateArrayObj, METH_VARARGS,
"Copy from the device to a numpy ndarray"},
{"__copy__",
(PyCFunction)CudaNdarray_View, METH_NOARGS,
......
......@@ -473,7 +473,7 @@ DllExport int CudaNdarray_CopyFromCudaNdarray(CudaNdarray * self,
* Transfer the contents of CudaNdarray `self` to a new numpy ndarray.
*/
DllExport PyObject *
CudaNdarray_CreateArrayObj(CudaNdarray * self);
CudaNdarray_CreateArrayObj(CudaNdarray * self, PyObject *args = NULL);
DllExport PyObject *
CudaNdarray_ZEROS(int n, int * dims);
......
......@@ -38,6 +38,17 @@ def test_host_to_device():
c = numpy.asarray(b)
assert numpy.all(a == c)
# test with float32 dtype
d = numpy.asarray(b, dtype='float32')
assert numpy.all(a == d)
# test with not float32 dtype
try:
numpy.asarray(b, dtype='int8')
assert False
except TypeError:
pass
def test_add_iadd_idiv():
for shapes in (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论