提交 330a038c 作者: Frederic

CudaNdarray_setitem now accepts more cases when the new value is on the CPU.

上级 5fd2cb21
......@@ -1684,6 +1684,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
if (verbose) fprintf(stderr, "CudaNdarray_setitem start\n");
// We try to copy directly into this CudaNdarray from the ndarray
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
CudaNdarray* new_value = NULL;
if(!rval){
// CudaNdarray_Subscript failed and set the error msg.
......@@ -1719,60 +1720,18 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
fprintf(stderr,
"CudaNdarray_setitem dest is a CudaNdarray and"
" value is a ndarray\n");
int typenum = PyArray_TYPE(value);
if (typenum != REAL_TYPENUM){
PyErr_SetString(PyExc_TypeError,
"CudaNdarray.__setitem__: can only copy from"
" float32 arrays");
Py_XDECREF(rval);
return -1;
}
if(! CudaNdarray_is_c_contiguous(rval)){
PyErr_SetString(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: When the new value is"
" an ndarray the part where we copy it to must be"
" c contiguous.");
Py_XDECREF(rval);
return -1;
}
if(rval->nd != ((PyArrayObject*)value)->nd){
PyErr_Format(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: need same number of dims."
" destination nd=%d, source nd=%d. broadcasting"
" implemented only for zeroing values from"
" python scalar.",
rval->nd,((PyArrayObject*)value)->nd);
Py_XDECREF(rval);
new_value = (CudaNdarray*) CudaNdarray_New();
if(!new_value)
{
return -1;
}
for(int i=0 ; i<rval->nd ; i++){
if(CudaNdarray_HOST_DIMS(rval)[i] != ((PyArrayObject*)value)->dimensions[i]){
PyErr_Format(PyExc_ValueError,
"CudaNdarray.__setitem__: need same dimensions for dim %d,"
" destination=%d, source=%ld",
i,
CudaNdarray_HOST_DIMS(rval)[i],
(long int)(((PyArrayObject*)value)->dimensions[i]));
if(CudaNdarray_CopyFromArray(new_value, (PyArrayObject *) value))
{
Py_XDECREF(new_value);
Py_XDECREF(rval);
return -1;
}
}
PyArrayObject * py_v = (PyArrayObject*)PyArray_ContiguousFromAny(
(PyObject*)value, typenum,
rval->nd, rval->nd);
cublasSetVector(PyArray_SIZE(py_v),
sizeof(real),
PyArray_DATA(py_v), 1,
rval->devdata, 1);
CNDA_THREAD_SYNC;
Py_XDECREF(py_v);
Py_XDECREF(rval);
if (CUBLAS_STATUS_SUCCESS != cublasGetError()){
PyErr_SetString(PyExc_RuntimeError,
"CudaNdarray.__setitem__: error copying ndarray data to device memory");
return -1;
}
return 0;
value = (PyObject *) new_value;
}
else if ((intobj=PyNumber_Int(value)))
{
......@@ -1817,6 +1776,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
PyErr_SetString(PyExc_TypeError,
"CudaNdarray.__setitem__: left must be a CudaNdarrays and right"
" must be a CudaNdarrays, an ndarray or a python scalar of value 0.");
Py_XDECREF(new_value);
return -1;
}
......@@ -1828,6 +1788,8 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
PyErr_SetString(PyExc_RuntimeError,
"CudaNdarray.__setitem__: syncing structure to device failed");
Py_DECREF(rval);
Py_XDECREF(new_value);
if (verbose)
fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1;
......@@ -1838,6 +1800,8 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)value, true))
{
Py_DECREF((PyObject*)rval);
Py_XDECREF(new_value);
if (verbose)
fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1;
......@@ -1848,6 +1812,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
// Clean up locally-created references
Py_DECREF(rval);
Py_XDECREF(new_value);
return 0;
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论