提交 330a038c authored 作者: Frederic's avatar Frederic

CudaNdarray_setitem now accept more case when the new value is on the cpu.

上级 5fd2cb21
...@@ -1684,6 +1684,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1684,6 +1684,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
if (verbose) fprintf(stderr, "CudaNdarray_setitem start\n"); if (verbose) fprintf(stderr, "CudaNdarray_setitem start\n");
// We try to copy directly into this CudaNdarray from the ndarray // We try to copy directly into this CudaNdarray from the ndarray
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key); CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
CudaNdarray* new_value = NULL;
if(!rval){ if(!rval){
// CudaNdarray_Subscript failed and set the error msg. // CudaNdarray_Subscript failed and set the error msg.
...@@ -1719,60 +1720,18 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1719,60 +1720,18 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
fprintf(stderr, fprintf(stderr,
"CudaNdarray_setitem dest is a CudaNdarray and" "CudaNdarray_setitem dest is a CudaNdarray and"
" value is a ndarray\n"); " value is a ndarray\n");
int typenum = PyArray_TYPE(value); new_value = (CudaNdarray*) CudaNdarray_New();
if (typenum != REAL_TYPENUM){ if(!new_value)
PyErr_SetString(PyExc_TypeError, {
"CudaNdarray.__setitem__: can only copy from"
" float32 arrays");
Py_XDECREF(rval);
return -1;
}
if(! CudaNdarray_is_c_contiguous(rval)){
PyErr_SetString(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: When the new value is"
" an ndarray the part where we copy it to must be"
" c contiguous.");
Py_XDECREF(rval);
return -1;
}
if(rval->nd != ((PyArrayObject*)value)->nd){
PyErr_Format(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: need same number of dims."
" destination nd=%d, source nd=%d. broadcasting"
" implemented only for zeroing values from"
" python scalar.",
rval->nd,((PyArrayObject*)value)->nd);
Py_XDECREF(rval);
return -1; return -1;
} }
for(int i=0 ; i<rval->nd ; i++){ if(CudaNdarray_CopyFromArray(new_value, (PyArrayObject *) value))
if(CudaNdarray_HOST_DIMS(rval)[i] != ((PyArrayObject*)value)->dimensions[i]){ {
PyErr_Format(PyExc_ValueError, Py_XDECREF(new_value);
"CudaNdarray.__setitem__: need same dimensions for dim %d,"
" destination=%d, source=%ld",
i,
CudaNdarray_HOST_DIMS(rval)[i],
(long int)(((PyArrayObject*)value)->dimensions[i]));
Py_XDECREF(rval); Py_XDECREF(rval);
return -1; return -1;
}
}
PyArrayObject * py_v = (PyArrayObject*)PyArray_ContiguousFromAny(
(PyObject*)value, typenum,
rval->nd, rval->nd);
cublasSetVector(PyArray_SIZE(py_v),
sizeof(real),
PyArray_DATA(py_v), 1,
rval->devdata, 1);
CNDA_THREAD_SYNC;
Py_XDECREF(py_v);
Py_XDECREF(rval);
if (CUBLAS_STATUS_SUCCESS != cublasGetError()){
PyErr_SetString(PyExc_RuntimeError,
"CudaNdarray.__setitem__: error copying ndarray data to device memory");
return -1;
} }
return 0; value = (PyObject *) new_value;
} }
else if ((intobj=PyNumber_Int(value))) else if ((intobj=PyNumber_Int(value)))
{ {
...@@ -1817,6 +1776,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1817,6 +1776,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"CudaNdarray.__setitem__: left must be a CudaNdarrays and right" "CudaNdarray.__setitem__: left must be a CudaNdarrays and right"
" must be a CudaNdarrays, an ndarray or a python scalar of value 0."); " must be a CudaNdarrays, an ndarray or a python scalar of value 0.");
Py_XDECREF(new_value);
return -1; return -1;
} }
...@@ -1828,6 +1788,8 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1828,6 +1788,8 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
PyErr_SetString(PyExc_RuntimeError, PyErr_SetString(PyExc_RuntimeError,
"CudaNdarray.__setitem__: syncing structure to device failed"); "CudaNdarray.__setitem__: syncing structure to device failed");
Py_DECREF(rval); Py_DECREF(rval);
Py_XDECREF(new_value);
if (verbose) if (verbose)
fprintf(stderr, "CudaNdarray_setitem error end\n"); fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1; return -1;
...@@ -1838,6 +1800,8 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1838,6 +1800,8 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)value, true)) if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)value, true))
{ {
Py_DECREF((PyObject*)rval); Py_DECREF((PyObject*)rval);
Py_XDECREF(new_value);
if (verbose) if (verbose)
fprintf(stderr, "CudaNdarray_setitem error end\n"); fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1; return -1;
...@@ -1848,6 +1812,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1848,6 +1812,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
// Clean up locally-created references // Clean up locally-created references
Py_DECREF(rval); Py_DECREF(rval);
Py_XDECREF(new_value);
return 0; return 0;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论