提交 c33d5f99 authored 作者: Frederic Bastien's avatar Frederic Bastien

'New mechanism to create a CudaNdarray from a pycuda gpu memory and example/test for it.'

上级 63b17b09
"""
This file is an example of view the memory allocated by pycuda in a GpuArray
in a CudaNdarray to be able to use it in Theano.
This also serve as a test for the function: cuda_ndarray.from_gpu_pointer
"""
import sys
import numpy
import theano
import theano.sandbox.cuda as cuda_ndarray
import theano.misc.pycuda_init
if not theano.misc.pycuda_init.pycuda_available:
from nose.plugins.skip import SkipTest
import pdb;pdb.set_trace()
raise SkipTest("Pycuda not installed. Skip test of theano op with pycuda code.")
if cuda_ndarray.cuda_available == False:
from nose.plugins.skip import SkipTest
import pdb;pdb.set_trace()
raise SkipTest('Optional package cuda disabled')
import pycuda
import pycuda.driver as drv
import pycuda.gpuarray
def test_pycuda_simple():
x = cuda_ndarray.CudaNdarray.zeros((5,5))
from pycuda.compiler import SourceModule
mod = SourceModule("""
__global__ void multiply_them(float *dest, float *a, float *b)
{
const int i = threadIdx.x;
dest[i] = a[i] * b[i];
}
""")
multiply_them = mod.get_function("multiply_them")
a = numpy.random.randn(100).astype(numpy.float32)
b = numpy.random.randn(100).astype(numpy.float32)
dest = numpy.zeros_like(a)
multiply_them(
drv.Out(dest), drv.In(a), drv.In(b),
block=(400,1,1), grid=(1,1))
assert (dest==a*b).all()
def test_pycuda_memory_to_theano():
#Test that we can use the GpuArray memory space in pycuda in a CudaNdarray
y = pycuda.gpuarray.zeros((3,4,5), 'float32')
print numpy.asarray(y)
print "gpuarray ref count before creating a CudaNdarray", sys.getrefcount(y)
assert sys.getrefcount(y)==2
rand = numpy.random.randn(*y.shape).astype(numpy.float32)
cuda_rand = cuda_ndarray.CudaNdarray(rand)
strides = [1]
for i in y.shape[::-1][:-1]:
strides.append(strides[-1]*i)
strides = tuple(strides[::-1])
print 'strides', strides
assert cuda_rand._strides == strides, (cuda_rand._strides, strides)
z = cuda_ndarray.from_gpu_pointer(y.ptr, y.shape, strides, y)
print "gpuarray ref count after creating a CudaNdarray", sys.getrefcount(y)
assert sys.getrefcount(y)==3
assert (numpy.asarray(z) == 0).all()
cuda_ones = cuda_ndarray.CudaNdarray(numpy.asarray([[[1]]],dtype='float32'))
z += cuda_ones
assert (numpy.asarray(z) == numpy.ones(y.shape)).all()
assert (numpy.asarray(z) == 1).all()
assert cuda_rand.shape == z.shape
assert cuda_rand._strides == z._strides, (cuda_rand._strides, z._strides)
assert (numpy.asarray(cuda_rand) == rand).all()
z += cuda_rand
assert (numpy.asarray(z)==(rand+1)).all()
# Check that the ref count to the gpuarray is right.
del z
print "gpuarray ref count after deleting the CudaNdarray", sys.getrefcount(y)
assert sys.getrefcount(y)==2
......@@ -2010,6 +2010,100 @@ CudaNdarray_gpu_shutdown(PyObject* _unused, PyObject* _unused_args) {
return Py_None;
}
/*
* This function is tested in theano/misc/test_pycuda_theano_simple.py
*/
PyObject *
CudaNdarray_from_gpu_pointer(PyObject* _unused, PyObject* args)
{
PyObject *gpu_ptr = NULL;
PyObject *shapes = NULL;
PyObject *strides = NULL;
PyObject *base = NULL;
PyObject *rval = NULL;
//args should consist of 3 python objects
//The first is the gpu ptr
//The second if the shape
//The third if the strides
if (! PyArg_ParseTuple(args, "OOOO", &gpu_ptr, &shapes, &strides, &base))
return NULL;
printf("In CudaNdarray_from_gpu_pointer\n");
if (!PyLong_Check(gpu_ptr))
{
PyErr_Format(PyExc_Exception, "CudaNdarray_from_gpu_pointer: The gpu pointor is not an long");
return NULL;
}
Py_ssize_t nd = PyObject_Length(shapes);
if (nd < 0)
{
PyErr_SetString(PyExc_TypeError, "CudaNdarray_from_gpu_pointer: Couldn't get length of second argument");
return NULL;
}
Py_ssize_t nd_stride = PyObject_Length(strides);
if (nd_stride < 0)
{
PyErr_SetString(PyExc_TypeError, "CudaNdarray_from_gpu_pointer: Couldn't get length of third argument");
return NULL;
}
if (nd != nd_stride)
{
PyErr_SetString(PyExc_TypeError, "CudaNdarray_from_gpu_pointer: We need the same number of shapes and strides");
return NULL;
}
rval = CudaNdarray_new_null();
if (CudaNdarray_set_nd((CudaNdarray *)rval, nd))
{
//CudaNdarray_set_nd set the error msg
return NULL;
}
// set gpu pointeur
assert(((CudaNdarray *)rval)->data_allocated == 0);
if (CudaNdarray_set_device_data((CudaNdarray *)rval, (float *)PyInt_AsLong(gpu_ptr), base))
{
PyErr_SetString(PyExc_TypeError, "CudaNdarray_from_gpu_pointer: Error while setting the gpu pointor");
return NULL;
}
// Set dims and strides
for (int i = nd-1; i >= 0; --i)
{
PyObject * idx = PyLong_FromLong(i);
if (idx == NULL)
{
PyErr_SetString(PyExc_Exception, "CudaNdarray_from_gpu_pointer: Couldn't make long object to loop over list/tuple");
return NULL;
}
PyObject* dim_ = PyObject_GetItem(shapes, idx);
PyObject* strd_ = PyObject_GetItem(strides, idx);
if (!PyInt_Check(dim_))
{
PyErr_Format(PyExc_Exception, "CudaNdarray_from_gpu_pointer: shapes[%d] is not an int", i);
return NULL;
}
if (!PyInt_Check(strd_))
{
PyErr_Format(PyExc_Exception, "CudaNdarray_from_gpu_pointer: strides[%d] is not an int", i);
return NULL;
}
int dim = PyInt_AsLong(dim_);
int strd = PyInt_AsLong(strd_);
CudaNdarray_set_stride((CudaNdarray *)rval, i, strd);
CudaNdarray_set_dim((CudaNdarray *)rval, i, dim);
Py_DECREF(idx);
Py_DECREF(dim_);
Py_DECREF(strd_);
}
printf("CudaNdarray_from_gpu_pointer normal return\n");
return rval;
}
PyObject *
CudaNdarray_Dot(PyObject* _unused, PyObject* args)
{
......@@ -2175,6 +2269,7 @@ static PyMethodDef module_methods[] = {
{"ptr_int_size", CudaNdarray_ptr_int_size, METH_VARARGS, "Return a tuple with the size of gpu pointer, cpu pointer and int in bytes."},
{"filter", filter, METH_VARARGS, "filter(obj, broadcastable, strict, storage) returns a CudaNdarray initialized to obj if it matches the constraints of broadcastable. strict=True prevents any numeric casting. If storage is a CudaNdarray it may be overwritten and used as the return value."},
{"outstanding_mallocs", outstanding_mallocs, METH_VARARGS, "how many more mallocs have been called than free's"},
{"from_gpu_pointer", CudaNdarray_from_gpu_pointer, METH_VARARGS, "Used to create a CudaNdarray from already allocated memory on the gpu.(example by pycuda)"},
{NULL, NULL, NULL, NULL} /* Sentinel */
};
......@@ -2367,7 +2462,7 @@ CudaNdarray_new_nd(int nd)
return (PyObject *) rval;
}
int CudaNdarray_set_device_data(CudaNdarray * self, float * data, CudaNdarray * base)
int CudaNdarray_set_device_data(CudaNdarray * self, float * data, PyObject * base)
{
if (self->data_allocated)
{
......@@ -2380,10 +2475,10 @@ int CudaNdarray_set_device_data(CudaNdarray * self, float * data, CudaNdarray *
}
}
//N.B. XDECREF and XINCREF are no-ops for NULL pointers
if (self->base != (PyObject*)base)
if (self->base != base)
{
Py_XDECREF(self->base);
self->base = (PyObject*)base;
self->base = base;
Py_XINCREF(self->base);
}
self->data_allocated = 0;
......
......@@ -438,7 +438,11 @@ CudaNdarray_NewDims(int nd, const inttype * dims)
*
* Set self to be a view of given `data`, owned by existing CudaNdarray `base`.
*/
int CudaNdarray_set_device_data(CudaNdarray * self, float * data, CudaNdarray * base);
int CudaNdarray_set_device_data(CudaNdarray * self, float * data, PyObject * base);
int CudaNdarray_set_device_data(CudaNdarray * self, float * data, CudaNdarray * base)
{
return CudaNdarray_set_device_data(self, data, (PyObject *) base);
}
/**
* Return an independent copy of self
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论