提交 ca715940 authored 作者: fsavard's avatar fsavard

Added setitem for CudaNdarray based on suggestions from Fred. Can only assign…

Added setitem for CudaNdarray based on suggestions from Fred. Can only assign another CudaNdarray, and no support for broadcast in assignment, for the moment.
上级 4858a352
...@@ -1186,10 +1186,53 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1186,10 +1186,53 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
return py_rval; return py_rval;
} }
// Will by called by __setitem__ in Python
// See http://docs.python.org/dev/py3k/c-api/object.html#PyObject_SetItem
// Doesn't handle broadcasting, e.g. a[:] = 5
// Can only be assigned from a CudaNdarray on the right side
static int
CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
{
if(!CudaNdarray_Check(o) || !CudaNdarray_Check(v))
{
PyErr_SetString(PyExc_TypeError, "both left and right of setitem must be CudaNdarrays");
return -1;
}
// Check that 'v' is compatible?
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
if(rval == NULL)
{
// Actually error string was probably set if we get a NULL, so we leave it as it is
//PyErr_SetString(PyExc_RuntimeError, "__getitem__ returned an error");
return -1;
}
else if((rval != (CudaNdarray*)o && rval->data_allocated) ||
(rval != (CudaNdarray*)o && rval->base != o))
{
// This case shouldn't happen, based on what I see in Subscript
// but just in case it happens sometime in the future
PyErr_SetString(PyExc_RuntimeError, "__getitem__ must return a CudaNdarray that refers to the original CudaNdarray, not a copy.");
Py_DECREF(rval);
return -1;
}
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)v))
{
Py_DECREF(rval);
return -1;
}
// If it fails, deallocate memory (DECREF?)
return 0;
}
PyMappingMethods CudaNdarrayMappingMethods = { PyMappingMethods CudaNdarrayMappingMethods = {
CudaNdarray_len, //lenfunc mp_length; __len__ CudaNdarray_len, //lenfunc mp_length; __len__
CudaNdarray_Subscript, //binaryfunc mp_subscript; __getitem__ CudaNdarray_Subscript, //binaryfunc mp_subscript; __getitem__
0, //objobjargproc mp_ass_subscript; __setitem__ CudaNdarray_setitem //objobjargproc mp_ass_subscript; __setitem__
}; };
//////////////////// ////////////////////
......
...@@ -298,3 +298,96 @@ def test_gemm_vector_vector(): ...@@ -298,3 +298,96 @@ def test_gemm_vector_vector():
_c = cuda_ndarray.dot(_b,_a) _c = cuda_ndarray.dot(_b,_a)
assert _c.shape == (1,1) assert _c.shape == (1,1)
assert numpy.allclose(_c, numpy.dot(b, a)) assert numpy.allclose(_c, numpy.dot(b, a))
# ---------------------------------------------------------------------
def test_setitem_matrixvector1():
a = theano._asarray([[0,1,2], [3,4,5]], dtype='float32')
_a = cuda_ndarray.CudaNdarray(a)
b = theano._asarray([8,9], dtype='float32')
_b = cuda_ndarray.CudaNdarray(b)
# set second column to 8,9
_a[:,1] = _b
assert numpy.all(numpy.asarray(_a[:,1]) == b)
def test_setitem_matrix_tensor3():
a = numpy.arange(27)
a.resize((3,3,3))
a = theano._asarray(a, dtype='float32')
_a = cuda_ndarray.CudaNdarray(a)
b = theano._asarray([7,8,9], dtype='float32')
_b = cuda_ndarray.CudaNdarray(b)
# set middle row through cube to 7,8,9
_a[:,1,1] = _b
assert numpy.all(numpy.asarray(_a[:,1,1]) == b)
def test_setitem_assign_to_slice():
a = numpy.arange(27)
a.resize((3,3,3))
a = theano._asarray(a, dtype='float32')
_a = cuda_ndarray.CudaNdarray(a)
b = theano._asarray([7,8,9], dtype='float32')
_b = cuda_ndarray.CudaNdarray(b)
# first get a slice of a
_c = _a[:,:,1]
# set middle row through cube to 7,8,9
# (this corresponds to middle row of matrix _c)
_c[:,1] = _b
assert numpy.all(numpy.asarray(_a[:,1,1]) == b)
# this fails for the moment
def test_setitem_broadcast_must_fail():
a = numpy.arange(27)
a.resize((3,3,3))
a = theano._asarray(a, dtype='float32')
_a = cuda_ndarray.CudaNdarray(a)
b = theano._asarray([7,8,9], dtype='float32')
_b = cuda_ndarray.CudaNdarray(b)
try:
# attempt to assign vector to all rows of this submatrix
_a[:,:,1] = _b
assert False
except TypeError:
assert True
# this also fails for the moment
def test_setitem_rightvalue_ndarray_fails():
a = numpy.arange(27)
a.resize((3,3,3))
a = theano._asarray(a, dtype='float32')
_a = cuda_ndarray.CudaNdarray(a)
b = theano._asarray([7,8,9], dtype='float32')
_b = cuda_ndarray.CudaNdarray(b)
try:
# attempt to assign the ndarray b with setitem
_a[:,:,1] = b
assert False
except TypeError, e:
#print e
assert True
'''
if __name__ == '__main__':
test_setitem_matrixvector1()
test_setitem_matrix_tensor3()
test_setitem_broadcast_must_fail()
test_setitem_assign_to_slice()
test_setitem_rightvalue_ndarray_fails()
'''
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论