提交 0a6269ce authored 作者: Frederic's avatar Frederic

Make CudaNdarray_setitem support setting a python int of value 0.

This calls cudaMemset under the hood, which is needed for Scan on the GPU.
上级 dd4f0c5c
......@@ -1674,19 +1674,44 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
// Doesn't handle broadcasting, e.g. a[:] = 5
// Can only be assigned from a CudaNdarray on the right side
// Or a ndarray when the left side part is c contiguous.
// Or a python scalar with value 0 when the left side part is c contiguous.
static int
CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
{
int verbose = 0;
if (verbose) fprintf(stderr, "CudaNdarray_setitem start\n");
// We try to copy directly into this CudaNdarray from the ndarray
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
if(!rval){
// CudaNdarray_Subscript failed and set the error msg.
Py_XDECREF(rval);
return -1;
}
if(rval != (CudaNdarray*)o &&
(rval->data_allocated ||
// The new array should have a base
!(((CudaNdarray*)rval)->base) ||
// If the original array has no base, the base of the new
// array should be the original one
(!((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != o) ||
// Else, the two arrays should have the same base
(((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != ((CudaNdarray*)o)->base)))
{
// This case shouldn't happen, based on what I see in Subscript
// but just in case it happens sometime in the future
PyErr_Format(PyExc_RuntimeError, "__getitem__ must return a CudaNdarray that refers to the original CudaNdarray, not a copy. rval.base=%p o.base=%p o=%p",
(((CudaNdarray*)rval)->base), ((CudaNdarray*)o)->base, o);
Py_DECREF(rval);
return -1;
}
PyObject * intobj = NULL;
if(CudaNdarray_Check(o) && PyArray_Check(value)){
// We try to copy directly into this CudaNdarray from the ndarray
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
if (verbose) fprintf(stderr, "CudaNdarray_setitem dest is a CudaNdarray and value is a ndarray\n");
int typenum = PyArray_TYPE(value);
if(!rval){
// CudaNdarray_Subscript failed and set the error msg.
Py_XDECREF(rval);
return -1;
}
if (typenum != REAL_TYPENUM){
PyErr_SetString(PyExc_TypeError, "CudaNdarray.__setitem__: can only copy from float32 arrays");
Py_XDECREF(rval);
......@@ -1698,7 +1723,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
return -1;
}
if(rval->nd != ((PyArrayObject*)value)->nd){
PyErr_Format(PyExc_NotImplementedError, "CudaNdarray.__setitem__: need same number of dims. destination nd=%d, source nd=%d. No broadcasting implemented.",
PyErr_Format(PyExc_NotImplementedError, "CudaNdarray.__setitem__: need same number of dims. destination nd=%d, source nd=%d. broadcasting implemented only for zeroing values from python scalar.",
rval->nd,((PyArrayObject*)value)->nd);
Py_XDECREF(rval);
return -1;
......@@ -1728,45 +1753,51 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
}
return 0;
}
if(!CudaNdarray_Check(o) || !CudaNdarray_Check(value))
else if ((intobj=PyNumber_Int(value)))
{
PyErr_SetString(PyExc_TypeError, "CudaNdarray.__setitem__: left must be a CudaNdarrays and right must be a CudaNdarrays or ndarray");
return -1;
if (verbose) fprintf(stderr, "CudaNdarray_setitem dest and value is a python number\n");
if(! CudaNdarray_is_c_contiguous(rval)){
PyErr_SetString(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: When the new value is a scalar of value 0 the part where we copy to must be c contiguous.");
Py_XDECREF(rval);
return -1;
}
long val = PyInt_AsLong(intobj);
Py_DECREF(intobj); intobj=NULL;
if (val == 0)
{
cudaError_t err = cudaMemset(rval->devdata, 0, CudaNdarray_SIZE(rval) * sizeof(real));
Py_XDECREF(rval);
if (err)
{
PyErr_SetString(PyExc_RuntimeError,
"CudaNdarray.__setitem__: cudaMemset failed");
return -1;
}
return 0;
} else {
Py_XDECREF(rval);
PyErr_SetString(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: we support setting only python scalar of value 0, numpy nd array and CudaNdarray.");
return -1;
}
}
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
PyErr_Clear(); // clear PyNumber_Int error.
if(rval == NULL)
{
// Actually error string was probably set if we get a NULL, so we leave it as it is
//PyErr_SetString(PyExc_RuntimeError, "__getitem__ returned an error");
return -1;
}
else if(rval != (CudaNdarray*)o &&
(rval->data_allocated ||
// The new array should have a base
!(((CudaNdarray*)rval)->base) ||
// If the original array has no base, the base of the new
// array should be the original one
(!((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != o) ||
// Else, the two arrays should have the same base
(((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != ((CudaNdarray*)o)->base)))
if(!CudaNdarray_Check(o) || !CudaNdarray_Check(value))
{
// This case shouldn't happen, based on what I see in Subscript
// but just in case it happens sometime in the future
PyErr_Format(PyExc_RuntimeError, "__getitem__ must return a CudaNdarray that refers to the original CudaNdarray, not a copy. rval.base=%p o.base=%p o=%p",
(((CudaNdarray*)rval)->base), ((CudaNdarray*)o)->base, o);
Py_DECREF(rval);
PyErr_SetString(PyExc_TypeError, "CudaNdarray.__setitem__: left must be a CudaNdarrays and right must be a CudaNdarrays, an ndarray or a python scalar of value 0.");
return -1;
}
if (verbose) fprintf(stderr, "CudaNdarray_setitem dest and value are CudaNdarray\n");
if (cnda_copy_structure_to_device(rval))
{
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray.__setitem__: syncing structure to device failed");
Py_DECREF(rval);
if (verbose) fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1;
}
......@@ -1775,6 +1806,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)value, true))
{
Py_DECREF((PyObject*)rval);
if (verbose) fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1;
}
......@@ -2547,6 +2579,7 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
Py_DECREF(py_src);
return 0;
}
bool
CudaNdarray_is_c_contiguous(const CudaNdarray * self)
{
......
......@@ -595,6 +595,10 @@ def test_setitem_matrixscalar0():
a[1,1] = theano._asarray(888, dtype='float32')
assert numpy.allclose(a,numpy.asarray(_a))
# broadcast a 0
_a[1, 1] = 0
_a[0:2] = 0
_a[1:] = 0
def test_setitem_matrixvector1():
a = theano._asarray([[0,1,2], [3,4,5]], dtype='float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论