提交 6550577b authored 作者: Frederic's avatar Frederic

make gpu reshape reuse gpu copy that is faster.

This also fix an not understood crash in gpu reshape code when the input is not contiguous.
上级 5001ae40
...@@ -567,38 +567,6 @@ PyObject * CudaNdarray_ReduceSum(CudaNdarray * self, PyObject * py_reduce_mask) ...@@ -567,38 +567,6 @@ PyObject * CudaNdarray_ReduceSum(CudaNdarray * self, PyObject * py_reduce_mask)
return (PyObject*)self_sum; return (PyObject*)self_sum;
} }
__global__ void k_copy_reshape_rowmajor(unsigned int numEls,
unsigned int a_nd, const float * a_data, const int * a_dim, const int * a_str,
unsigned int z_nd, float * z_data, const int * z_dim, const int * z_str)
{
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int numThreads = blockDim.x * gridDim.x;
for (unsigned int i = idx; i < numEls; i += numThreads)
{
const float * a_i = a_data;
unsigned int a_ii = i;
for (unsigned int _d = 0; _d < a_nd; ++_d) //make the rightmost coords change fastest
{
unsigned int d = a_nd - _d-1;
unsigned int a_i_d = a_ii % a_dim[d];
a_ii = a_ii / a_dim[d];
a_i += a_i_d * a_str[d];
}
unsigned int z_ii = i;
float * z_i = z_data;
for (unsigned int _d = 0; _d < z_nd; ++_d) //make the rightmost coords change fastest
{
unsigned int d = z_nd - _d-1;
//i tried to make the for loop count down, but it didn't work!?
unsigned int z_i_d = z_ii % z_dim[d];
z_i += z_i_d * z_str[d];
z_ii = z_ii / z_dim[d];
}
z_i[0] = a_i[0]; //copy one lousy float!
}
}
// Reshape self to the new shape gived by the tuple shape. // Reshape self to the new shape gived by the tuple shape.
// //
// If self is c contiguous, it return a view. Otherwise it always do a copy. // If self is c contiguous, it return a view. Otherwise it always do a copy.
...@@ -606,6 +574,22 @@ __global__ void k_copy_reshape_rowmajor(unsigned int numEls, ...@@ -606,6 +574,22 @@ __global__ void k_copy_reshape_rowmajor(unsigned int numEls,
// c contiguous // c contiguous
PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape) PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape)
{ {
if(!CudaNdarray_is_c_contiguous(self))
{
// allocate new space
//TODO: test to see if we can re-use old one and take a new param to
// use this
CudaNdarray* rval = (CudaNdarray*) CudaNdarray_Copy(self);
if (!rval)
{
return NULL;
}
CudaNdarray* ret = (CudaNdarray*) CudaNdarray_Reshape(rval, shape);
Py_XDECREF(rval);
return (PyObject*)ret;
}
// check shape tuple // check shape tuple
unsigned int rval_nd; unsigned int rval_nd;
unsigned int * rval_dims; unsigned int * rval_dims;
...@@ -656,9 +640,8 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape) ...@@ -656,9 +640,8 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape)
return rval; return rval;
} }
if(CudaNdarray_is_c_contiguous(self))
{
//return a view, not a copy //return a view, not a copy
//we can do this as we checked self is c_contiguous
CudaNdarray * rval = (CudaNdarray * )CudaNdarray_New(rval_nd); CudaNdarray * rval = (CudaNdarray * )CudaNdarray_New(rval_nd);
if (!rval || 0 != rval->data_allocated if (!rval || 0 != rval->data_allocated
...@@ -678,53 +661,8 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape) ...@@ -678,53 +661,8 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape)
} }
free(rval_dims); free(rval_dims);
return (PyObject*)rval; return (PyObject*)rval;
}
// allocate new space (TODO: test to see if we can re-use old one)
CudaNdarray * rval = (CudaNdarray * )CudaNdarray_New();
if (!rval || CudaNdarray_alloc_contiguous(rval, rval_nd, rval_dims)){
Py_XDECREF(rval);
free(rval_dims);
return NULL;
}
// call worker routine
unsigned int threads_per_block = std::min(rval_size, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
unsigned int n_blocks = std::min(ceil_intdiv(rval_size,threads_per_block), (unsigned int)NUM_VECTOR_OP_BLOCKS);
k_copy_reshape_rowmajor<<<n_blocks,threads_per_block>>>(
rval_size,
self->nd,
CudaNdarray_DEV_DATA(self), CudaNdarray_DEV_DIMS(self), CudaNdarray_DEV_STRIDES(self),
rval->nd,
CudaNdarray_DEV_DATA(rval), CudaNdarray_DEV_DIMS(rval), CudaNdarray_DEV_STRIDES(rval));
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
Py_DECREF(rval);
PyObject * shape_inp = CudaNdarray_get_shape(self, NULL);
PyObject * shape_inp2 = PyObject_Str(shape_inp);
PyObject * shape_dest = PyObject_Str(shape);
PyErr_Format(PyExc_RuntimeError,
"Cuda error in CudaNdarray_Reshape"
"()n_blocks=%d, n_threads=%d, input_shape=%s,"
" dest_shape=%s): %s: %s.\n",
n_blocks, threads_per_block,
PyString_AsString(shape_inp2),
PyString_AsString(shape_dest),
"k_copy_reshape_rowmajor",
cudaGetErrorString(err)
);
Py_DECREF(shape_dest);
Py_DECREF(shape_inp);
Py_DECREF(shape_inp2);
free(rval_dims);
return NULL;
}
free(rval_dims);
return (PyObject*)rval;
} }
PyObject * CudaNdarray_View(CudaNdarray * self) PyObject * CudaNdarray_View(CudaNdarray * self)
{ {
CudaNdarray * rval = (CudaNdarray*)CudaNdarray_New(self->nd); CudaNdarray * rval = (CudaNdarray*)CudaNdarray_New(self->nd);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论