提交 3727eaeb 作者: Mathieu Germain 提交者: Marc-Alexandre Cote

First draft of the CUDA cumsum.

上级 d826e4da
......@@ -24,9 +24,8 @@ class GpuCumsum(CumsumOp, GpuOp):
out_type = x.type()
if self.axis is None:
if self.axis is None and x.ndim > 1:
out_type = CudaNdarrayType(broadcastable=(False,), dtype=x.dtype)
return theano.Apply(self, [x], [out_type])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
......@@ -56,55 +55,57 @@ class GpuCumsum(CumsumOp, GpuOp):
def c_support_code_apply(self, node, nodename):
axis = self.axis
return """
static __global__ void k_cumsum_1D_%(nodename)s(float* g_idata,
float* g_odata,
int n)
{
extern __shared__ float temp[2*blockDim.x];
int stride = 1;
__global__
void finalCumSum_1D_%(nodename)s(float * output, float * blockSum) {
int globalThreadID = (blockIdx.x + 1) * blockDim.x + threadIdx.x;
temp[2*threadIdx.x] = g_idata[2*threadIdx.x];
temp[2*threadIdx.x+1] = g_idata[2*threadIdx.x+1];
const float currentBlockSum = blockSum[blockIdx.x];
for (int d = n/2; d > 0; d /= 2)
{
__syncthreads();
if (threadIdx.x < d)
{
int ai = stride*(2*threadIdx.x+1)-1;
int bi = stride*(2*threadIdx.x+2)-1;
temp[bi] += temp[ai];
output[globalThreadID * 2] += currentBlockSum;
output[(globalThreadID * 2) + 1] += currentBlockSum;
}
stride *= 2;
}
if (threadIdx.x == 0) { temp[n - 1] = 0; } // NOt sure about that
__global__
void blockCumSum_1D_%(nodename)s(float * input, float * output, int numElements, float * blockSum) {
int globalThreadID = blockIdx.x * blockDim.x + threadIdx.x;
for (int d = 1; d < n; d *= 2)
{
__syncthreads();
if (globalThreadID < numElements/2) {
extern __shared__ float partialCumSum[];
// Load data in shared memory
partialCumSum[threadIdx.x*2] = input[globalThreadID*2];
partialCumSum[(threadIdx.x *2) +1] = input[(globalThreadID * 2) + 1];
if (threadIdx.x < d)
{
int ai = stride*(2*threadIdx.x+1)-1;
int bi = stride*(2*threadIdx.x+2)-1;
// Reduction Phase
for (int stride = 1; stride < blockDim.x*2; stride *= 2) {
__syncthreads();
int index = (threadIdx.x + 1) * (stride * 2) - 1;
if(index < blockDim.x*2) {
partialCumSum[index] += partialCumSum[index - stride];
}
}
float t = temp[ai];
temp[ai] = temp[bi];
temp[bi] += t;
// Reverse Phase
for (int stride = blockDim.x*2/2; stride > 0; stride /= 2) {
__syncthreads();
int index = (threadIdx.x + 1) * (stride * 2) - 1;
if(index + stride < blockDim.x*2) {
partialCumSum[index + stride] += partialCumSum[index];
}
}
// Write the final output to global memory
__syncthreads();
g_odata[2*threadIdx.x] = temp[2*threadIdx.x];
g_odata[2*threadIdx.x+1] = temp[2*threadIdx.x+1];
output[globalThreadID * 2] = partialCumSum[threadIdx.x * 2];
output[(globalThreadID * 2) + 1] = partialCumSum[(threadIdx.x * 2) + 1];
if (threadIdx.x == blockDim.x - 1) {
blockSum[blockIdx.x] = partialCumSum[(threadIdx.x * 2) + 1];
}
}
}
""" % locals()
def c_code(self, node, name, inames, onames, sub):
def c_code(self, node, nodename, inames, onames, sub):
x, = inames
z, = onames
axis = self.axis
......@@ -123,46 +124,56 @@ class GpuCumsum(CumsumOp, GpuOp):
code = """
npy_intp shape[1] = { CudaNdarray_SIZE(%(x)s) };
if(! (%(z)s && CudaNdarray_HOST_DIMS(%(z)s)[0] == shape[0]) )
{
if(! (%(z)s && CudaNdarray_HOST_DIMS(%(z)s)[0] == shape[0]) ) {
Py_XDECREF(%(z)s);
%(z)s = (CudaNdarray*) CudaNdarray_NewDims(1, shape);
}
if (!%(z)s)
if (!%(z)s) {
%(fail)s;
{
dim3 dim_block( min((int)shape[0], %(max_threads_dim0)s) );
dim3 dim_grid(1);
if (dim_block.x < shape[0])
dim_grid.x = (shape[0]-1 / dim_block.x) + 1; // Ceil
}
{ // Namespace for kernel calls //
int blockSize = min((int)shape[0], %(max_threads_dim0)s/2);
int dimGridX = ceil(shape[0] / (2.0*blockSize));
npy_intp WARDFRT[1] = { dimGridX };
CudaNdarray * deviceBlockSum = (CudaNdarray*) CudaNdarray_NewDims(1, WARDFRT);
void (*f)(float*, float*, int);
f = k_cumsum_1D_%(name)s;
dim3 dimBlock(blockSize, 1, 1);
dim3 dimGrid(dimGridX, 1, 1);
f<<<dim_grid,dim_block>>>(CudaNdarray_DEV_DATA(%(x)s),
blockCumSum_1D_%(nodename)s<<<dimGrid, dimBlock>>>
(
CudaNdarray_DEV_DATA(%(x)s),
CudaNdarray_DEV_DATA(%(z)s),
shape[0]);
CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error: %%s: %%s. (grid: %%i x %%i;"
" block: %%i x %%i x %%i; shared: %%i)\\n",
"k_cumsum_1D_%(name)s",
cudaGetErrorString(sts),
dim_grid.x,
dim_grid.y,
dim_block.x,
dim_block.y,
dim_block.z,
0);
%(fail)s;
shape[0],
CudaNdarray_DEV_DATA(deviceBlockSum)
);
if (dimGridX > 1) {
cudaThreadSynchronize();
dim3 dimGridBlockSum(1, 1, 1);
dim3 dimBlockBlockSum(dimGridX-1, 1, 1);
blockCumSum_1D_%(nodename)s<<<dimGridBlockSum, dimBlockBlockSum, (2*blockSize) * sizeof(float)>>>
(
CudaNdarray_DEV_DATA(deviceBlockSum),
CudaNdarray_DEV_DATA(deviceBlockSum),
dimGridX-1,
NULL
);
cudaThreadSynchronize();
dim3 dimGrid(dimGridX-1, 1, 1);
dim3 dimBlock(blockSize, 1, 1);
finalCumSum_1D_%(nodename)s<<<dimGrid, dimBlock>>>
(
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_DEV_DATA(deviceBlockSum)
);
}
cudaDeviceSynchronize();
}
""" % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论