提交 e87df287 authored 作者: lamblin's avatar lamblin

Merge pull request #338 from nouiz/cudandarray_setitem

Cudandarray setitem
...@@ -1674,19 +1674,44 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1674,19 +1674,44 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
// Doesn't handle broadcasting, e.g. a[:] = 5 // Doesn't handle broadcasting, e.g. a[:] = 5
// Can only be assigned from a CudaNdarray on the right side // Can only be assigned from a CudaNdarray on the right side
// Or a ndarray when the left side part is c contiguous. // Or a ndarray when the left side part is c contiguous.
// Or a python scalar with value 0 when the left side part is c contiguous.
static int static int
CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
{ {
int verbose = 0;
if (verbose) fprintf(stderr, "CudaNdarray_setitem start\n");
// We try to copy directly into this CudaNdarray from the ndarray
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
if(!rval){
// CudaNdarray_Subscript failed and set the error msg.
Py_XDECREF(rval);
return -1;
}
if(rval != (CudaNdarray*)o &&
(rval->data_allocated ||
// The new array should have a base
!(((CudaNdarray*)rval)->base) ||
// If the original array has no base, the base of the new
// array should be the original one
(!((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != o) ||
// Else, the two arrays should have the same base
(((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != ((CudaNdarray*)o)->base)))
{
// This case shouldn't happen, based on what I see in Subscript
// but just in case it happens sometime in the future
PyErr_Format(PyExc_RuntimeError, "__getitem__ must return a CudaNdarray that refers to the original CudaNdarray, not a copy. rval.base=%p o.base=%p o=%p",
(((CudaNdarray*)rval)->base), ((CudaNdarray*)o)->base, o);
Py_DECREF(rval);
return -1;
}
PyObject * intobj = NULL;
if(CudaNdarray_Check(o) && PyArray_Check(value)){ if(CudaNdarray_Check(o) && PyArray_Check(value)){
// We try to copy directly into this CudaNdarray from the ndarray if (verbose) fprintf(stderr, "CudaNdarray_setitem dest is a CudaNdarray and value is a ndarray\n");
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
int typenum = PyArray_TYPE(value); int typenum = PyArray_TYPE(value);
if(!rval){
// CudaNdarray_Subscript failed and set the error msg.
Py_XDECREF(rval);
return -1;
}
if (typenum != REAL_TYPENUM){ if (typenum != REAL_TYPENUM){
PyErr_SetString(PyExc_TypeError, "CudaNdarray.__setitem__: can only copy from float32 arrays"); PyErr_SetString(PyExc_TypeError, "CudaNdarray.__setitem__: can only copy from float32 arrays");
Py_XDECREF(rval); Py_XDECREF(rval);
...@@ -1698,7 +1723,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1698,7 +1723,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
return -1; return -1;
} }
if(rval->nd != ((PyArrayObject*)value)->nd){ if(rval->nd != ((PyArrayObject*)value)->nd){
PyErr_Format(PyExc_NotImplementedError, "CudaNdarray.__setitem__: need same number of dims. destination nd=%d, source nd=%d. No broadcasting implemented.", PyErr_Format(PyExc_NotImplementedError, "CudaNdarray.__setitem__: need same number of dims. destination nd=%d, source nd=%d. broadcasting implemented only for zeroing values from python scalar.",
rval->nd,((PyArrayObject*)value)->nd); rval->nd,((PyArrayObject*)value)->nd);
Py_XDECREF(rval); Py_XDECREF(rval);
return -1; return -1;
...@@ -1728,45 +1753,51 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1728,45 +1753,51 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
} }
return 0; return 0;
} }
else if ((intobj=PyNumber_Int(value)))
if(!CudaNdarray_Check(o) || !CudaNdarray_Check(value))
{ {
PyErr_SetString(PyExc_TypeError, "CudaNdarray.__setitem__: left must be a CudaNdarrays and right must be a CudaNdarrays or ndarray"); if (verbose) fprintf(stderr, "CudaNdarray_setitem dest and value is a python number\n");
return -1; if(! CudaNdarray_is_c_contiguous(rval)){
PyErr_SetString(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: When the new value is a scalar of value 0 the part where we copy to must be c contiguous.");
Py_XDECREF(rval);
return -1;
}
long val = PyInt_AsLong(intobj);
Py_DECREF(intobj); intobj=NULL;
if (val == 0)
{
cudaError_t err = cudaMemset(rval->devdata, 0, CudaNdarray_SIZE(rval) * sizeof(real));
Py_XDECREF(rval);
if (err)
{
PyErr_SetString(PyExc_RuntimeError,
"CudaNdarray.__setitem__: cudaMemset failed");
return -1;
}
return 0;
} else {
Py_XDECREF(rval);
PyErr_SetString(PyExc_NotImplementedError,
"CudaNdarray.__setitem__: we support setting only python scalar of value 0, numpy nd array and CudaNdarray.");
return -1;
}
} }
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key); PyErr_Clear(); // clear PyNumber_Int error.
if(rval == NULL) if(!CudaNdarray_Check(o) || !CudaNdarray_Check(value))
{
// Actually error string was probably set if we get a NULL, so we leave it as it is
//PyErr_SetString(PyExc_RuntimeError, "__getitem__ returned an error");
return -1;
}
else if(rval != (CudaNdarray*)o &&
(rval->data_allocated ||
// The new array should have a base
!(((CudaNdarray*)rval)->base) ||
// If the original array has no base, the base of the new
// array should be the original one
(!((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != o) ||
// Else, the two arrays should have the same base
(((CudaNdarray*)o)->base && ((CudaNdarray*)rval)->base != ((CudaNdarray*)o)->base)))
{ {
// This case shouldn't happen, based on what I see in Subscript PyErr_SetString(PyExc_TypeError, "CudaNdarray.__setitem__: left must be a CudaNdarrays and right must be a CudaNdarrays, an ndarray or a python scalar of value 0.");
// but just in case it happens sometime in the future
PyErr_Format(PyExc_RuntimeError, "__getitem__ must return a CudaNdarray that refers to the original CudaNdarray, not a copy. rval.base=%p o.base=%p o=%p",
(((CudaNdarray*)rval)->base), ((CudaNdarray*)o)->base, o);
Py_DECREF(rval);
return -1; return -1;
} }
if (verbose) fprintf(stderr, "CudaNdarray_setitem dest and value are CudaNdarray\n");
if (cnda_copy_structure_to_device(rval)) if (cnda_copy_structure_to_device(rval))
{ {
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray.__setitem__: syncing structure to device failed"); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray.__setitem__: syncing structure to device failed");
Py_DECREF(rval); Py_DECREF(rval);
if (verbose) fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1; return -1;
} }
...@@ -1775,6 +1806,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value) ...@@ -1775,6 +1806,7 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *value)
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)value, true)) if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)value, true))
{ {
Py_DECREF((PyObject*)rval); Py_DECREF((PyObject*)rval);
if (verbose) fprintf(stderr, "CudaNdarray_setitem error end\n");
return -1; return -1;
} }
...@@ -2547,6 +2579,7 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj) ...@@ -2547,6 +2579,7 @@ CudaNdarray_CopyFromArray(CudaNdarray * self, PyArrayObject*obj)
Py_DECREF(py_src); Py_DECREF(py_src);
return 0; return 0;
} }
bool bool
CudaNdarray_is_c_contiguous(const CudaNdarray * self) CudaNdarray_is_c_contiguous(const CudaNdarray * self)
{ {
......
...@@ -35,7 +35,7 @@ def tes_use(): ...@@ -35,7 +35,7 @@ def tes_use():
def test_sum(): def test_sum():
""" """
test sum pattern 1, 11, 10, 01, 100, 110, 011, 001, 111, 0011, 0101, 0111, 1011, 1111 test sum pattern 1, 11, 10, 01, 001, 010, 100, 110, 011, 111, 0011, 0101, 0111, 1011, 1111
test sum pattern implemented with reshape: test sum pattern implemented with reshape:
1000, 0100, 0010, 0001, 11111 1000, 0100, 0010, 0001, 11111
......
...@@ -595,6 +595,10 @@ def test_setitem_matrixscalar0(): ...@@ -595,6 +595,10 @@ def test_setitem_matrixscalar0():
a[1,1] = theano._asarray(888, dtype='float32') a[1,1] = theano._asarray(888, dtype='float32')
assert numpy.allclose(a,numpy.asarray(_a)) assert numpy.allclose(a,numpy.asarray(_a))
# broadcast a 0
_a[1, 1] = 0
_a[0:2] = 0
_a[1:] = 0
def test_setitem_matrixvector1(): def test_setitem_matrixvector1():
a = theano._asarray([[0,1,2], [3,4,5]], dtype='float32') a = theano._asarray([[0,1,2], [3,4,5]], dtype='float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论